diff --git a/generators/python/core_utilities/shared/pydantic_utilities.py b/generators/python/core_utilities/shared/pydantic_utilities.py index 0cc5665a880b..c99b57c00490 100644 --- a/generators/python/core_utilities/shared/pydantic_utilities.py +++ b/generators/python/core_utilities/shared/pydantic_utilities.py @@ -133,111 +133,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -246,66 +156,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/generators/python/core_utilities/shared/with_pydantic_v1_on_v2/with_aliases/pydantic_utilities.py b/generators/python/core_utilities/shared/with_pydantic_v1_on_v2/with_aliases/pydantic_utilities.py index 775fadca8dca..fe0bae1043c0 100644 --- a/generators/python/core_utilities/shared/with_pydantic_v1_on_v2/with_aliases/pydantic_utilities.py +++ b/generators/python/core_utilities/shared/with_pydantic_v1_on_v2/with_aliases/pydantic_utilities.py @@ -1,15 +1,13 @@ # nopycln: file import datetime as dt -import inspect import json import logging import warnings from collections import defaultdict from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Tuple, Type, TypeVar, Union, cast import pydantic -import typing_extensions from .datetime_utils import serialize_datetime if TYPE_CHECKING: @@ -52,93 +50,21 @@ def parse_date(value: Any) -> dt.date: AnyCallable = Callable[..., Any] -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation_v1(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic v1 model.""" - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.v1.BaseModel)): - continue - - disc_annotation = _get_field_annotation_v1(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -147,66 +73,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation_v1(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) def parse_obj_as(type_: Type[T], object_: Any) -> T: diff --git a/generators/python/sdk/changes/5.14.4/fix-sse-discrimination.yml b/generators/python/sdk/changes/5.14.4/fix-sse-discrimination.yml new file mode 100644 index 000000000000..acebd7176f48 --- /dev/null +++ b/generators/python/sdk/changes/5.14.4/fix-sse-discrimination.yml @@ -0,0 +1,13 @@ +# yaml-language-server: $schema=../../../../../fern-changes-yml.schema.json + +- summary: | + Fix SSE union discrimination for both data-level and protocol-level contexts. + + Data-level: Simplify `parse_sse_obj` to always parse the SSE `data` field as + JSON instead of using runtime heuristics. This fixes incorrect routing when the + discriminant field name (e.g., `event`) collides with an SSE envelope field. + + Protocol-level: Generate an if/elif dispatch chain at code-generation time that + routes on `_sse.event`, parsing each variant's data payload into its concrete + type. No runtime heuristic — the generator decides the code path statically. + type: fix diff --git a/generators/python/sdk/versions.yml b/generators/python/sdk/versions.yml index cd29cf81fe9d..b0ddc9e4c3c8 100644 --- a/generators/python/sdk/versions.yml +++ b/generators/python/sdk/versions.yml @@ -1,4 +1,19 @@ # yaml-language-server: $schema=../../../fern-versions-yml.schema.json +- version: 5.14.4 + changelogEntry: + - summary: | + Fix SSE union discrimination for both data-level and protocol-level contexts. + + Data-level: Simplify `parse_sse_obj` to always parse the SSE `data` field as + JSON instead of using runtime heuristics. This fixes incorrect routing when the + discriminant field name (e.g., `event`) collides with an SSE envelope field. + + Protocol-level: Generate an if/elif dispatch chain at code-generation time that + routes on `_sse.event`, parsing each variant's data payload into its concrete + type. No runtime heuristic — the generator decides the code path statically. + type: fix + createdAt: "2026-05-28" + irVersion: 67 - version: 5.14.3 changelogEntry: - summary: | @@ -751,6 +766,7 @@ createdAt: "2026-03-27" irVersion: 65 + - version: 5.0.8 changelogEntry: - summary: | diff --git a/generators/python/src/fern_python/generators/sdk/client_generator/endpoint_response_code_writer.py b/generators/python/src/fern_python/generators/sdk/client_generator/endpoint_response_code_writer.py index 0f39f7d0c8b0..381d2295894a 100644 --- a/generators/python/src/fern_python/generators/sdk/client_generator/endpoint_response_code_writer.py +++ b/generators/python/src/fern_python/generators/sdk/client_generator/endpoint_response_code_writer.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Sequence, Tuple from ..context.sdk_generator_context import SdkGeneratorContext from fern_python.codegen import AST @@ -108,6 +108,12 @@ def _handle_success_stream(self, *, writer: AST.NodeWriter, stream_response: ir_ stream_response_union = stream_response.get_as_union() if stream_response_union.type == "sse": + protocol_info = self._get_protocol_discriminated_union_info(stream_response_union.payload) + sse_for_body = self._build_sse_for_body( + stream_response=stream_response, + stream_response_union=stream_response_union, + protocol_info=protocol_info, + ) iter_func_body.extend( [ AST.VariableDeclaration( @@ -151,85 +157,8 @@ def _handle_success_stream(self, *, writer: AST.NodeWriter, stream_response: ir_ ], else_code=None, ), - AST.TryStatement( - body=[ - AST.YieldStatement( - self._context.core_utilities.get_construct_sse( - self._get_streaming_response_data_type(stream_response), - AST.Expression(f"{EndpointResponseCodeWriter.SSE_VARIABLE}"), - ), - ), - ], - handlers=[ - AST.ExceptHandler( - body=[ - AST.Expression( - AST.FunctionInvocation( - function_definition=AST.Reference( - qualified_name_excluding_import=(), - import_=AST.ReferenceImport( - module=AST.Module.built_in(("logging",)), - named_import="warning", - ), - ), - args=[ - AST.Expression( - f'f"Skipping SSE event with invalid JSON: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' - ) - ], - ) - ), - ], - exception_type="JSONDecodeError", - name="e", - ), - AST.ExceptHandler( - body=[ - AST.Expression( - AST.FunctionInvocation( - function_definition=AST.Reference( - qualified_name_excluding_import=(), - import_=AST.ReferenceImport( - module=AST.Module.built_in(("logging",)), - named_import="warning", - ), - ), - args=[ - AST.Expression( - f'f"Skipping SSE event due to model construction error: {{type(e).__name__}}: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' - ) - ], - ) - ), - ], - exception_type="(TypeError, ValueError, KeyError, AttributeError)", - name="e", - ), - AST.ExceptHandler( - body=[ - AST.Expression( - AST.FunctionInvocation( - function_definition=AST.Reference( - qualified_name_excluding_import=(), - import_=AST.ReferenceImport( - module=AST.Module.built_in(("logging",)), - named_import="error", - ), - ), - args=[ - AST.Expression( - f'f"Unexpected error processing SSE event: {{type(e).__name__}}: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' - ) - ], - ) - ), - ], - exception_type="Exception", - name="e", - ), - ], - ), - ], + ] + + sse_for_body, is_async=self._is_async, ), ] @@ -871,3 +800,196 @@ def _get_streaming_response_data_type(self, streaming_response: ir_types.Streami if union.type == "text": return AST.TypeHint.str_() raise RuntimeError(f"{union.type} streaming response is unsupported") + + def _get_protocol_discriminated_union_info( + self, payload: ir_types.TypeReference + ) -> Optional[Sequence[Tuple[str, ir_types.SingleUnionType]]]: + """Check if payload is a protocol-discriminated union and return variant info. + + Returns a list of (wire_value, SingleUnionType) tuples if the payload is + a named union type with discriminator_context == "protocol", else None. + """ + payload_union = payload.get_as_union() + if payload_union.type != "named": + return None + type_declaration = self._context.pydantic_generator_context.get_declaration_for_type_id(payload_union.type_id) + shape_union = type_declaration.shape.get_as_union() + if shape_union.type != "union": + return None + union_decl: ir_types.UnionTypeDeclaration = shape_union + if union_decl.discriminator_context is None or union_decl.discriminator_context.value != "protocol": + return None + return [(get_wire_value(variant.discriminant_value), variant) for variant in union_decl.types] + + def _get_variant_type_hint(self, variant: ir_types.SingleUnionType) -> AST.TypeHint: + """Get the type hint for a single union variant's data shape.""" + shape_union = variant.shape.get_as_union() + if shape_union.properties_type == "samePropertiesAsObject": + named_type = ir_types.NamedType( + type_id=shape_union.type_id, + fern_filepath=shape_union.fern_filepath, + name=shape_union.name, + ) + return self._context.pydantic_generator_context.get_type_hint_for_type_reference( + ir_types.TypeReference.factory.named(named_type) + ) + if shape_union.properties_type == "singleProperty": + return self._context.pydantic_generator_context.get_type_hint_for_type_reference(shape_union.type) + # noProperties — yield the raw parsed data as the overall union type + return AST.TypeHint.any() + + def _build_sse_for_body( + self, + *, + stream_response: ir_types.StreamingResponse, + stream_response_union: Any, + protocol_info: Optional[Sequence[Tuple[str, ir_types.SingleUnionType]]], + ) -> list[AST.AstNode]: + """Build the list of AST nodes inside the SSE for-loop body. + + For data-level discrimination (protocol_info is None) this uses + parse_sse_obj. For protocol-level discrimination it emits an + if/elif chain that dispatches on _sse.event. + """ + if protocol_info is None: + return self._build_data_level_sse_body(stream_response) + return self._build_protocol_level_sse_body(protocol_info) + + def _build_data_level_sse_body( + self, + stream_response: ir_types.StreamingResponse, + ) -> list[AST.AstNode]: + """Generate a try/yield block using parse_sse_obj for data-level discrimination.""" + return [ + AST.TryStatement( + body=[ + AST.YieldStatement( + self._context.core_utilities.get_construct_sse( + self._get_streaming_response_data_type(stream_response), + AST.Expression(f"{EndpointResponseCodeWriter.SSE_VARIABLE}"), + ), + ), + ], + handlers=[ + AST.ExceptHandler( + body=[ + AST.Expression( + AST.FunctionInvocation( + function_definition=AST.Reference( + qualified_name_excluding_import=(), + import_=AST.ReferenceImport( + module=AST.Module.built_in(("logging",)), + named_import="warning", + ), + ), + args=[ + AST.Expression( + f'f"Skipping SSE event with invalid JSON: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' + ) + ], + ) + ), + ], + exception_type="JSONDecodeError", + name="e", + ), + AST.ExceptHandler( + body=[ + AST.Expression( + AST.FunctionInvocation( + function_definition=AST.Reference( + qualified_name_excluding_import=(), + import_=AST.ReferenceImport( + module=AST.Module.built_in(("logging",)), + named_import="warning", + ), + ), + args=[ + AST.Expression( + f'f"Skipping SSE event due to model construction error: {{type(e).__name__}}: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' + ) + ], + ) + ), + ], + exception_type="(TypeError, ValueError, KeyError, AttributeError)", + name="e", + ), + AST.ExceptHandler( + body=[ + AST.Expression( + AST.FunctionInvocation( + function_definition=AST.Reference( + qualified_name_excluding_import=(), + import_=AST.ReferenceImport( + module=AST.Module.built_in(("logging",)), + named_import="error", + ), + ), + args=[ + AST.Expression( + f'f"Unexpected error processing SSE event: {{type(e).__name__}}: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' + ) + ], + ) + ), + ], + exception_type="Exception", + name="e", + ), + ], + ), + ] + + def _build_protocol_level_sse_body( + self, + protocol_info: Sequence[Tuple[str, ir_types.SingleUnionType]], + ) -> list[AST.AstNode]: + """Generate an if/elif chain dispatching on _sse.event for protocol-level discrimination.""" + conditions: list[AST.IfConditionLeaf] = [] + for wire_value, variant in protocol_info: + variant_type_hint = self._get_variant_type_hint(variant) + yield_expr = self._context.core_utilities.get_construct( + variant_type_hint, + AST.Expression(Json.loads(AST.Expression(f"{EndpointResponseCodeWriter.SSE_VARIABLE}.data"))), + ) + conditions.append( + AST.IfConditionLeaf( + condition=AST.Expression(f"{EndpointResponseCodeWriter.SSE_VARIABLE}.event == {repr(wire_value)}"), + code=[ + AST.TryStatement( + body=[AST.YieldStatement(yield_expr)], + handlers=[ + AST.ExceptHandler( + body=[ + AST.Expression( + AST.FunctionInvocation( + function_definition=AST.Reference( + qualified_name_excluding_import=(), + import_=AST.ReferenceImport( + module=AST.Module.built_in(("logging",)), + named_import="warning", + ), + ), + args=[ + AST.Expression( + f'f"Failed to parse SSE event {repr(wire_value)}: {{e}}, sse: {{{EndpointResponseCodeWriter.SSE_VARIABLE}!r}}"' + ) + ], + ) + ), + ], + exception_type="Exception", + name="e", + ), + ], + ), + ], + ) + ) + return [ + AST.ConditionalTree( + conditions=conditions, + else_code=None, + ), + ] diff --git a/seed/python-sdk/accept-header/poetry.lock b/seed/python-sdk/accept-header/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/accept-header/poetry.lock +++ b/seed/python-sdk/accept-header/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/accept-header/src/seed/core/pydantic_utilities.py b/seed/python-sdk/accept-header/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/accept-header/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/accept-header/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/alias-extends/no-custom-config/poetry.lock b/seed/python-sdk/alias-extends/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/alias-extends/no-custom-config/poetry.lock +++ b/seed/python-sdk/alias-extends/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/alias-extends/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/alias-extends/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/alias-extends/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/alias-extends/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/poetry.lock b/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/poetry.lock +++ b/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py b/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/alias-extends/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/alias/poetry.lock b/seed/python-sdk/alias/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/alias/poetry.lock +++ b/seed/python-sdk/alias/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/alias/src/seed/core/pydantic_utilities.py b/seed/python-sdk/alias/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/alias/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/alias/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/allof-inline/no-custom-config/poetry.lock b/seed/python-sdk/allof-inline/no-custom-config/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/allof-inline/no-custom-config/poetry.lock +++ b/seed/python-sdk/allof-inline/no-custom-config/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/allof-inline/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/allof-inline/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/allof-inline/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/allof-inline/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/allof/no-custom-config/poetry.lock b/seed/python-sdk/allof/no-custom-config/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/allof/no-custom-config/poetry.lock +++ b/seed/python-sdk/allof/no-custom-config/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/allof/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/allof/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/allof/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/allof/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/any-auth/poetry.lock b/seed/python-sdk/any-auth/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/any-auth/poetry.lock +++ b/seed/python-sdk/any-auth/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/any-auth/src/seed/core/pydantic_utilities.py b/seed/python-sdk/any-auth/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/any-auth/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/any-auth/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/api-wide-base-path-with-default/poetry.lock b/seed/python-sdk/api-wide-base-path-with-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/api-wide-base-path-with-default/poetry.lock +++ b/seed/python-sdk/api-wide-base-path-with-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/api-wide-base-path-with-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/api-wide-base-path-with-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/api-wide-base-path-with-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/api-wide-base-path-with-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/api-wide-base-path/poetry.lock b/seed/python-sdk/api-wide-base-path/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/api-wide-base-path/poetry.lock +++ b/seed/python-sdk/api-wide-base-path/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/api-wide-base-path/src/seed/core/pydantic_utilities.py b/seed/python-sdk/api-wide-base-path/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/api-wide-base-path/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/api-wide-base-path/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/audiences/poetry.lock b/seed/python-sdk/audiences/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/audiences/poetry.lock +++ b/seed/python-sdk/audiences/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/audiences/src/seed/core/pydantic_utilities.py b/seed/python-sdk/audiences/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/audiences/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/audiences/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/basic-auth-environment-variables/poetry.lock b/seed/python-sdk/basic-auth-environment-variables/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/basic-auth-environment-variables/poetry.lock +++ b/seed/python-sdk/basic-auth-environment-variables/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/basic-auth-environment-variables/src/seed/core/pydantic_utilities.py b/seed/python-sdk/basic-auth-environment-variables/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/basic-auth-environment-variables/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/basic-auth-environment-variables/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/poetry.lock b/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/poetry.lock +++ b/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/basic-auth-pw-omitted/with-wire-tests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/basic-auth/poetry.lock b/seed/python-sdk/basic-auth/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/basic-auth/poetry.lock +++ b/seed/python-sdk/basic-auth/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/basic-auth/src/seed/core/pydantic_utilities.py b/seed/python-sdk/basic-auth/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/basic-auth/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/basic-auth/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/bearer-token-environment-variable/poetry.lock b/seed/python-sdk/bearer-token-environment-variable/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/bearer-token-environment-variable/poetry.lock +++ b/seed/python-sdk/bearer-token-environment-variable/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/bearer-token-environment-variable/src/seed/core/pydantic_utilities.py b/seed/python-sdk/bearer-token-environment-variable/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/bearer-token-environment-variable/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/bearer-token-environment-variable/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/bytes-download/poetry.lock b/seed/python-sdk/bytes-download/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/bytes-download/poetry.lock +++ b/seed/python-sdk/bytes-download/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/bytes-download/src/seed/core/pydantic_utilities.py b/seed/python-sdk/bytes-download/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/bytes-download/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/bytes-download/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/bytes-upload/poetry.lock b/seed/python-sdk/bytes-upload/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/bytes-upload/poetry.lock +++ b/seed/python-sdk/bytes-upload/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/bytes-upload/src/seed/core/pydantic_utilities.py b/seed/python-sdk/bytes-upload/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/bytes-upload/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/bytes-upload/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/poetry.lock b/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/poetry.lock +++ b/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py b/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/circular-references-advanced/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/circular-references-extends/no-custom-config/poetry.lock b/seed/python-sdk/circular-references-extends/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/circular-references-extends/no-custom-config/poetry.lock +++ b/seed/python-sdk/circular-references-extends/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/circular-references-extends/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/circular-references-extends/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/circular-references-extends/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/circular-references-extends/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/circular-references/no-custom-config/poetry.lock b/seed/python-sdk/circular-references/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/circular-references/no-custom-config/poetry.lock +++ b/seed/python-sdk/circular-references/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/circular-references/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/circular-references/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/circular-references/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/circular-references/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/circular-references/no-inheritance-for-extended-models/poetry.lock b/seed/python-sdk/circular-references/no-inheritance-for-extended-models/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/circular-references/no-inheritance-for-extended-models/poetry.lock +++ b/seed/python-sdk/circular-references/no-inheritance-for-extended-models/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/circular-references/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py b/seed/python-sdk/circular-references/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/circular-references/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/circular-references/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/cli-multi-spec-namespaced/poetry.lock b/seed/python-sdk/cli-multi-spec-namespaced/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/cli-multi-spec-namespaced/poetry.lock +++ b/seed/python-sdk/cli-multi-spec-namespaced/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/cli-multi-spec-namespaced/src/seed/core/pydantic_utilities.py b/seed/python-sdk/cli-multi-spec-namespaced/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/cli-multi-spec-namespaced/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/cli-multi-spec-namespaced/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/cli-multi-spec/poetry.lock b/seed/python-sdk/cli-multi-spec/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/cli-multi-spec/poetry.lock +++ b/seed/python-sdk/cli-multi-spec/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/cli-multi-spec/src/seed/core/pydantic_utilities.py b/seed/python-sdk/cli-multi-spec/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/cli-multi-spec/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/cli-multi-spec/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/client-side-params/poetry.lock b/seed/python-sdk/client-side-params/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/client-side-params/poetry.lock +++ b/seed/python-sdk/client-side-params/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/client-side-params/src/seed/core/pydantic_utilities.py b/seed/python-sdk/client-side-params/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/client-side-params/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/client-side-params/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/content-type/poetry.lock b/seed/python-sdk/content-type/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/content-type/poetry.lock +++ b/seed/python-sdk/content-type/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/content-type/src/seed/core/pydantic_utilities.py b/seed/python-sdk/content-type/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/content-type/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/content-type/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/cross-package-type-names/poetry.lock b/seed/python-sdk/cross-package-type-names/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/cross-package-type-names/poetry.lock +++ b/seed/python-sdk/cross-package-type-names/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/cross-package-type-names/src/seed/core/pydantic_utilities.py b/seed/python-sdk/cross-package-type-names/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/cross-package-type-names/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/cross-package-type-names/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/dollar-string-examples/poetry.lock b/seed/python-sdk/dollar-string-examples/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/dollar-string-examples/poetry.lock +++ b/seed/python-sdk/dollar-string-examples/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/dollar-string-examples/src/seed/core/pydantic_utilities.py b/seed/python-sdk/dollar-string-examples/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/dollar-string-examples/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/dollar-string-examples/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/empty-clients/poetry.lock b/seed/python-sdk/empty-clients/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/empty-clients/poetry.lock +++ b/seed/python-sdk/empty-clients/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/empty-clients/src/seed/core/pydantic_utilities.py b/seed/python-sdk/empty-clients/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/empty-clients/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/empty-clients/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/endpoint-security-auth/poetry.lock b/seed/python-sdk/endpoint-security-auth/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/endpoint-security-auth/poetry.lock +++ b/seed/python-sdk/endpoint-security-auth/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/endpoint-security-auth/src/seed/core/pydantic_utilities.py b/seed/python-sdk/endpoint-security-auth/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/endpoint-security-auth/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/endpoint-security-auth/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/enum/no-custom-config/poetry.lock b/seed/python-sdk/enum/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/enum/no-custom-config/poetry.lock +++ b/seed/python-sdk/enum/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/enum/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/enum/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/enum/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/enum/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/enum/real-enum-forward-compat/poetry.lock b/seed/python-sdk/enum/real-enum-forward-compat/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/enum/real-enum-forward-compat/poetry.lock +++ b/seed/python-sdk/enum/real-enum-forward-compat/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/enum/real-enum-forward-compat/src/seed/core/pydantic_utilities.py b/seed/python-sdk/enum/real-enum-forward-compat/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/enum/real-enum-forward-compat/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/enum/real-enum-forward-compat/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/enum/real-enum/poetry.lock b/seed/python-sdk/enum/real-enum/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/enum/real-enum/poetry.lock +++ b/seed/python-sdk/enum/real-enum/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/enum/real-enum/src/seed/core/pydantic_utilities.py b/seed/python-sdk/enum/real-enum/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/enum/real-enum/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/enum/real-enum/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/enum/strenum/poetry.lock b/seed/python-sdk/enum/strenum/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/enum/strenum/poetry.lock +++ b/seed/python-sdk/enum/strenum/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/enum/strenum/src/seed/core/pydantic_utilities.py b/seed/python-sdk/enum/strenum/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/enum/strenum/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/enum/strenum/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/error-property/poetry.lock b/seed/python-sdk/error-property/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/error-property/poetry.lock +++ b/seed/python-sdk/error-property/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/error-property/src/seed/core/pydantic_utilities.py b/seed/python-sdk/error-property/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/error-property/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/error-property/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/errors/poetry.lock b/seed/python-sdk/errors/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/errors/poetry.lock +++ b/seed/python-sdk/errors/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/errors/src/seed/core/pydantic_utilities.py b/seed/python-sdk/errors/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/errors/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/errors/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/examples/additional_init_exports_with_duplicates/poetry.lock b/seed/python-sdk/examples/additional_init_exports_with_duplicates/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/examples/additional_init_exports_with_duplicates/poetry.lock +++ b/seed/python-sdk/examples/additional_init_exports_with_duplicates/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/examples/additional_init_exports_with_duplicates/src/seed/core/pydantic_utilities.py b/seed/python-sdk/examples/additional_init_exports_with_duplicates/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/examples/additional_init_exports_with_duplicates/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/examples/additional_init_exports_with_duplicates/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/examples/client-filename/poetry.lock b/seed/python-sdk/examples/client-filename/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/examples/client-filename/poetry.lock +++ b/seed/python-sdk/examples/client-filename/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/examples/client-filename/src/seed/core/pydantic_utilities.py b/seed/python-sdk/examples/client-filename/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/examples/client-filename/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/examples/client-filename/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/examples/legacy-wire-tests/poetry.lock b/seed/python-sdk/examples/legacy-wire-tests/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/examples/legacy-wire-tests/poetry.lock +++ b/seed/python-sdk/examples/legacy-wire-tests/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/examples/legacy-wire-tests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/examples/legacy-wire-tests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/examples/legacy-wire-tests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/examples/legacy-wire-tests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/examples/no-custom-config/poetry.lock b/seed/python-sdk/examples/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/examples/no-custom-config/poetry.lock +++ b/seed/python-sdk/examples/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/examples/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/examples/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/examples/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/examples/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/examples/omit-fern-headers/poetry.lock b/seed/python-sdk/examples/omit-fern-headers/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/examples/omit-fern-headers/poetry.lock +++ b/seed/python-sdk/examples/omit-fern-headers/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/examples/omit-fern-headers/src/seed/core/pydantic_utilities.py b/seed/python-sdk/examples/omit-fern-headers/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/examples/omit-fern-headers/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/examples/omit-fern-headers/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/examples/readme/poetry.lock b/seed/python-sdk/examples/readme/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/examples/readme/poetry.lock +++ b/seed/python-sdk/examples/readme/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/examples/readme/src/seed/core/pydantic_utilities.py b/seed/python-sdk/examples/readme/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/examples/readme/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/examples/readme/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/additional_init_exports/poetry.lock b/seed/python-sdk/exhaustive/additional_init_exports/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/additional_init_exports/poetry.lock +++ b/seed/python-sdk/exhaustive/additional_init_exports/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/additional_init_exports/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/additional_init_exports/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/additional_init_exports/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/additional_init_exports/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/aliases_with_validation/poetry.lock b/seed/python-sdk/exhaustive/aliases_with_validation/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/aliases_with_validation/poetry.lock +++ b/seed/python-sdk/exhaustive/aliases_with_validation/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/aliases_without_validation/poetry.lock b/seed/python-sdk/exhaustive/aliases_without_validation/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/aliases_without_validation/poetry.lock +++ b/seed/python-sdk/exhaustive/aliases_without_validation/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/custom-transport/poetry.lock b/seed/python-sdk/exhaustive/custom-transport/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/custom-transport/poetry.lock +++ b/seed/python-sdk/exhaustive/custom-transport/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/custom-transport/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/custom-transport/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/custom-transport/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/custom-transport/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/datetime-milliseconds/poetry.lock b/seed/python-sdk/exhaustive/datetime-milliseconds/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/datetime-milliseconds/poetry.lock +++ b/seed/python-sdk/exhaustive/datetime-milliseconds/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/datetime-milliseconds/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/datetime-milliseconds/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/datetime-milliseconds/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/datetime-milliseconds/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/deps_with_min_python_version/poetry.lock b/seed/python-sdk/exhaustive/deps_with_min_python_version/poetry.lock index d0a4581b9fe9..2110ecbe03b4 100644 --- a/seed/python-sdk/exhaustive/deps_with_min_python_version/poetry.lock +++ b/seed/python-sdk/exhaustive/deps_with_min_python_version/poetry.lock @@ -917,19 +917,19 @@ files = [ [[package]] name = "langchain" -version = "1.3.1" +version = "1.3.2" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0.0,>=3.10.0" groups = ["dev"] files = [ - {file = "langchain-1.3.1-py3-none-any.whl", hash = "sha256:154e9c30c90b391eba4315296f6bf6b6fac6b058ddea4cc771a10470968fe36f"}, - {file = "langchain-1.3.1.tar.gz", hash = "sha256:bc283c220233230f48b8e50ab1fbf1b688bcb206d933fa448d40a9b143177f62"}, + {file = "langchain-1.3.2-py3-none-any.whl", hash = "sha256:900f6b3f4ee08b9ba3cdbe667dbf42525bd6f66a4a07a7f1db26262673e41ed6"}, + {file = "langchain-1.3.2.tar.gz", hash = "sha256:ffd5f204a46b5fa1a38bf89ba3b45ca0902c02d18fa7d2a2eaeaeb1f5bf19d0a"}, ] [package.dependencies] langchain-core = ">=1.4.0,<2.0.0" -langgraph = ">=1.2.0,<1.3.0" +langgraph = ">=1.2.2,<1.3.0" pydantic = ">=2.7.4,<3.0.0" [package.extras] @@ -1008,14 +1008,14 @@ typing-extensions = ">=4.7.0,<5.0.0" [[package]] name = "langgraph" -version = "1.2.1" +version = "1.2.2" description = "Building stateful, multi-actor applications with LLMs" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "langgraph-1.2.1-py3-none-any.whl", hash = "sha256:5cc4020de8f1e2a048d773f6e9128646a2af8c68a8067ab9cab177a2fcc8d221"}, - {file = "langgraph-1.2.1.tar.gz", hash = "sha256:28314f844678d9d307cbd63e7b48b0145bf17177d84b40ee2921061e07b6f966"}, + {file = "langgraph-1.2.2-py3-none-any.whl", hash = "sha256:0a851bf4ba5939c5474a2fd57e6b439b5315283e254e42943bd392c2d71a5e03"}, + {file = "langgraph-1.2.2.tar.gz", hash = "sha256:f54a98458976b3ff0774683867df125fb52d8dbedeb2441d0b0656a51331cee5"}, ] [package.dependencies] @@ -1028,14 +1028,14 @@ xxhash = ">=3.5.0" [[package]] name = "langgraph-checkpoint" -version = "4.1.0" +version = "4.1.1" description = "Library with base interfaces for LangGraph checkpoint savers." optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "langgraph_checkpoint-4.1.0-py3-none-any.whl", hash = "sha256:8bc2a0466a20c38b865ce6671b42093fd5c041133f32351cae4222e0eeaf7fb5"}, - {file = "langgraph_checkpoint-4.1.0.tar.gz", hash = "sha256:e5bb304e30fc1363ac8fcb5f7dee5ca2185d77fe475b0d01de2c5f91324c2c21"}, + {file = "langgraph_checkpoint-4.1.1-py3-none-any.whl", hash = "sha256:25d29144b082827218e7bc3f1e9b0566a4bb007895cd6cc26f66a8428739f56e"}, + {file = "langgraph_checkpoint-4.1.1.tar.gz", hash = "sha256:6c2bdb530c91f91d7d9c1bd100925d0fc4f498d418c17f3587d1526279482a25"}, ] [package.dependencies] @@ -1060,14 +1060,14 @@ langgraph-checkpoint = ">=2.1.0,<5.0.0" [[package]] name = "langgraph-sdk" -version = "0.3.14" +version = "0.3.15" description = "SDK for interacting with LangGraph API" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "langgraph_sdk-0.3.14-py3-none-any.whl", hash = "sha256:68935bf6f4924eda92617a9e5dfb4f4281197508c648cb9d62ff083907607f9d"}, - {file = "langgraph_sdk-0.3.14.tar.gz", hash = "sha256:acd1674c538e97f3cdaa610f6dd7e34bc9bad30167f0ccc482dcd563325e81f5"}, + {file = "langgraph_sdk-0.3.15-py3-none-any.whl", hash = "sha256:3838773acf7456d158165385d49f48f1e856f28b56ccd99ea139a8f27004815d"}, + {file = "langgraph_sdk-0.3.15.tar.gz", hash = "sha256:29e805003d2c6e296823dd71992610976fd0428cefaa8b3304fd91f2247037de"}, ] [package.dependencies] @@ -1076,14 +1076,14 @@ orjson = ">=3.11.5" [[package]] name = "langsmith" -version = "0.8.5" +version = "0.8.6" description = "Client library to connect to the LangSmith Observability and Evaluation Platform." optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "langsmith-0.8.5-py3-none-any.whl", hash = "sha256:efc779f9d450dcaf9d97bc8894f4926276509d6e730e05289af9a64debce06ae"}, - {file = "langsmith-0.8.5.tar.gz", hash = "sha256:3615243d99c12f4047f13042bdc05a373dce232d106a6511b3ca7b48c5af1c2c"}, + {file = "langsmith-0.8.6-py3-none-any.whl", hash = "sha256:b304888ea5ec5fe397db24f0bf474b0c8e472fb23ee36a2007e9837f6ff29cc1"}, + {file = "langsmith-0.8.6.tar.gz", hash = "sha256:a46fd3403c2de3a9c34f72ebb7b2e45872627671adcc67c6a4c571520b6931cc"}, ] [package.dependencies] @@ -1094,6 +1094,7 @@ pydantic = ">=2,<3" requests = ">=2.0.0" requests-toolbelt = ">=1.0.0" uuid-utils = ">=0.12.0,<1.0" +websockets = ">=15.0" xxhash = ">=3.0.0" zstandard = ">=0.23.0" @@ -1104,7 +1105,6 @@ langsmith-pyo3 = ["langsmith-pyo3 (>=0.1.0rc2)"] openai-agents = ["openai-agents (>=0.0.3)"] otel = ["opentelemetry-api (>=1.30.0)", "opentelemetry-exporter-otlp-proto-http (>=1.30.0)", "opentelemetry-sdk (>=1.30.0)"] pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4)", "vcrpy (>=7.0.0)"] -sandbox = ["websockets (>=15.0)"] strands-agents = ["opentelemetry-api (>=1.30.0)", "opentelemetry-exporter-otlp-proto-http (>=1.30.0)", "opentelemetry-sdk (>=1.30.0)", "strands-agents (>=0.1.0)", "strands-agents-tools (>=0.2.0)"] vcr = ["vcrpy (>=7.0.0)"] @@ -1861,23 +1861,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] @@ -2585,6 +2585,77 @@ files = [ {file = "uuid_utils-0.16.0.tar.gz", hash = "sha256:d6902d4375dfba4c9902c736bb82d3c040417b67f7d0fa48910ddfdb1ac95de7"}, ] +[[package]] +name = "websockets" +version = "16.0" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a"}, + {file = "websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0"}, + {file = "websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957"}, + {file = "websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72"}, + {file = "websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde"}, + {file = "websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3"}, + {file = "websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3"}, + {file = "websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9"}, + {file = "websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35"}, + {file = "websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8"}, + {file = "websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad"}, + {file = "websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d"}, + {file = "websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe"}, + {file = "websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b"}, + {file = "websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5"}, + {file = "websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64"}, + {file = "websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6"}, + {file = "websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac"}, + {file = "websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00"}, + {file = "websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79"}, + {file = "websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39"}, + {file = "websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c"}, + {file = "websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f"}, + {file = "websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1"}, + {file = "websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2"}, + {file = "websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89"}, + {file = "websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea"}, + {file = "websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9"}, + {file = "websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230"}, + {file = "websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c"}, + {file = "websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5"}, + {file = "websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82"}, + {file = "websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8"}, + {file = "websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f"}, + {file = "websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a"}, + {file = "websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156"}, + {file = "websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0"}, + {file = "websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904"}, + {file = "websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4"}, + {file = "websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e"}, + {file = "websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4"}, + {file = "websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1"}, + {file = "websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3"}, + {file = "websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8"}, + {file = "websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d"}, + {file = "websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244"}, + {file = "websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e"}, + {file = "websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641"}, + {file = "websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8"}, + {file = "websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e"}, + {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944"}, + {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206"}, + {file = "websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6"}, + {file = "websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd"}, + {file = "websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d"}, + {file = "websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03"}, + {file = "websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da"}, + {file = "websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c"}, + {file = "websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767"}, + {file = "websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec"}, + {file = "websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5"}, +] + [[package]] name = "xxhash" version = "3.7.0" diff --git a/seed/python-sdk/exhaustive/deps_with_min_python_version/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/deps_with_min_python_version/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/deps_with_min_python_version/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/deps_with_min_python_version/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/eager-imports/poetry.lock b/seed/python-sdk/exhaustive/eager-imports/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/eager-imports/poetry.lock +++ b/seed/python-sdk/exhaustive/eager-imports/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/eager-imports/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/eager-imports/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/eager-imports/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/eager-imports/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/extra_dependencies/poetry.lock b/seed/python-sdk/exhaustive/extra_dependencies/poetry.lock index d0271863c3ed..19ff9457e4fb 100644 --- a/seed/python-sdk/exhaustive/extra_dependencies/poetry.lock +++ b/seed/python-sdk/exhaustive/extra_dependencies/poetry.lock @@ -1251,23 +1251,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/extra_dependencies/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/extra_dependencies/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/extra_dependencies/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/extra_dependencies/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/extra_dev_dependencies/poetry.lock b/seed/python-sdk/exhaustive/extra_dev_dependencies/poetry.lock index 0a010ab2f2ef..84cb201b7fdb 100644 --- a/seed/python-sdk/exhaustive/extra_dev_dependencies/poetry.lock +++ b/seed/python-sdk/exhaustive/extra_dev_dependencies/poetry.lock @@ -1321,23 +1321,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/extra_dev_dependencies/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/extra_dev_dependencies/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/extra_dev_dependencies/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/extra_dev_dependencies/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/five-second-timeout/poetry.lock b/seed/python-sdk/exhaustive/five-second-timeout/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/five-second-timeout/poetry.lock +++ b/seed/python-sdk/exhaustive/five-second-timeout/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/five-second-timeout/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/five-second-timeout/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/five-second-timeout/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/five-second-timeout/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/follow_redirects_by_default/poetry.lock b/seed/python-sdk/exhaustive/follow_redirects_by_default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/follow_redirects_by_default/poetry.lock +++ b/seed/python-sdk/exhaustive/follow_redirects_by_default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/follow_redirects_by_default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/follow_redirects_by_default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/follow_redirects_by_default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/follow_redirects_by_default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/import-paths/poetry.lock b/seed/python-sdk/exhaustive/import-paths/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/import-paths/poetry.lock +++ b/seed/python-sdk/exhaustive/import-paths/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/import-paths/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/import-paths/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/import-paths/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/import-paths/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/improved_imports/poetry.lock b/seed/python-sdk/exhaustive/improved_imports/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/improved_imports/poetry.lock +++ b/seed/python-sdk/exhaustive/improved_imports/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/improved_imports/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/improved_imports/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/improved_imports/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/improved_imports/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/infinite-timeout/poetry.lock b/seed/python-sdk/exhaustive/infinite-timeout/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/infinite-timeout/poetry.lock +++ b/seed/python-sdk/exhaustive/infinite-timeout/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/infinite-timeout/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/infinite-timeout/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/infinite-timeout/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/infinite-timeout/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/inline-path-params/poetry.lock b/seed/python-sdk/exhaustive/inline-path-params/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/inline-path-params/poetry.lock +++ b/seed/python-sdk/exhaustive/inline-path-params/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/inline-path-params/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/inline-path-params/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/inline-path-params/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/inline-path-params/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/inline_request_params/poetry.lock b/seed/python-sdk/exhaustive/inline_request_params/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/inline_request_params/poetry.lock +++ b/seed/python-sdk/exhaustive/inline_request_params/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/inline_request_params/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/inline_request_params/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/inline_request_params/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/inline_request_params/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/no-custom-config/poetry.lock b/seed/python-sdk/exhaustive/no-custom-config/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/exhaustive/no-custom-config/poetry.lock +++ b/seed/python-sdk/exhaustive/no-custom-config/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/output-directory-project-root/poetry.lock b/seed/python-sdk/exhaustive/output-directory-project-root/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/output-directory-project-root/poetry.lock +++ b/seed/python-sdk/exhaustive/output-directory-project-root/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/output-directory-project-root/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/output-directory-project-root/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/output-directory-project-root/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/output-directory-project-root/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/output-directory-source-root-no-package-root/sub/dir/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/output-directory-source-root-no-package-root/sub/dir/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/output-directory-source-root-no-package-root/sub/dir/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/output-directory-source-root-no-package-root/sub/dir/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/output-directory-source-root-with-package-path/seed/my_org/my_sdk/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/output-directory-source-root-with-package-path/seed/my_org/my_sdk/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/output-directory-source-root-with-package-path/seed/my_org/my_sdk/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/output-directory-source-root-with-package-path/seed/my_org/my_sdk/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/output-directory-source-root/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/output-directory-source-root/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/output-directory-source-root/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/output-directory-source-root/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/package-path/poetry.lock b/seed/python-sdk/exhaustive/package-path/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/package-path/poetry.lock +++ b/seed/python-sdk/exhaustive/package-path/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/package-path/src/seed/matryoshka/doll/structure/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/package-path/src/seed/matryoshka/doll/structure/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/package-path/src/seed/matryoshka/doll/structure/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/package-path/src/seed/matryoshka/doll/structure/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pydantic-extra-fields/poetry.lock b/seed/python-sdk/exhaustive/pydantic-extra-fields/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/pydantic-extra-fields/poetry.lock +++ b/seed/python-sdk/exhaustive/pydantic-extra-fields/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pydantic-extra-fields/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pydantic-extra-fields/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pydantic-extra-fields/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pydantic-extra-fields/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pydantic-ignore-fields/poetry.lock b/seed/python-sdk/exhaustive/pydantic-ignore-fields/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/pydantic-ignore-fields/poetry.lock +++ b/seed/python-sdk/exhaustive/pydantic-ignore-fields/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pydantic-ignore-fields/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pydantic-ignore-fields/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pydantic-ignore-fields/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pydantic-ignore-fields/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pydantic-v1-with-utils/poetry.lock b/seed/python-sdk/exhaustive/pydantic-v1-with-utils/poetry.lock index 63fa2ad6c2c8..f2eeffef483a 100644 --- a/seed/python-sdk/exhaustive/pydantic-v1-with-utils/poetry.lock +++ b/seed/python-sdk/exhaustive/pydantic-v1-with-utils/poetry.lock @@ -977,132 +977,119 @@ email = ["email-validator (>=1.0.3)"] [[package]] name = "pydantic-core" -version = "2.46.4" +version = "2.47.0" description = "Core functionality for Pydantic validation and serialization" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "pydantic_core-2.46.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a396dcc17e5a0b164dbe026896245a4fa9ff402edca1dff0be3d53a517f74de4"}, - {file = "pydantic_core-2.46.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:da4b951fe36dc7c3a1ccb4e3cd1747c3542b8c9ceede8fc86cae054e764485f5"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb63e0198ca18aad131c089b9204c23079c3afa95487e561f4c522d519e55aba"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f47286a97f0bc9b8859519809077b91b2cefe4ae47fcbf5e466a009c1c5d742b"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a0ed8ea6f2d61c1738835f99b699348d7857379083e5fc497fa0c967a407c"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea793e075b70290d89d8142074262885d3f7da19634845135751bd6344f73b50"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395aebd9183f9d112f569aeb5b2214d1a10a33bec8456447f7fbdfa51d38d4cd"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:b078afbc25f3a1436c7a1d2cd3e322497ee99615ba97c563566fdf46aff1ee01"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f747929cf940cddb5b3668a390056ddd5ba2e5010615ea2dcf4f9c4f3ab8791d"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:daa27d92c36f24388fe3ad306b174781c747627f134452e4f128ea00ce1fe8c4"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:19e51f073cd3df251856a8a4189fbdf1de4012c3ebacfb1884f94f1eb406079f"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1747f85cee84c26985853c6f3d9bd3e75da5212912443fa111c113b9c246f39"}, - {file = "pydantic_core-2.46.4-cp310-cp310-win32.whl", hash = "sha256:2f84c03c8607173d16b5a854ec68a2f9079ae03237a54fb506d13af47e1d018d"}, - {file = "pydantic_core-2.46.4-cp310-cp310-win_amd64.whl", hash = "sha256:8358a950c8909158e3df31538a7e4edc2d7265a7c54b47f0864d9e5bae9dcebf"}, - {file = "pydantic_core-2.46.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:0e96592440881c74a213e5ad528e2b24d3d4f940de2766bed9010ab1d9e51594"}, - {file = "pydantic_core-2.46.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0d65b8c354be7fb5f720c3caa8bc940bc2d20ce749c8e06135f07f8ed95dd7c"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bfb192b3f4b9e8a89b6277b6ce787564f62cfd272055f6e685726b111dc7826"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9037063db01f09b09e237c282b6792bd4da634b5402c4e7f0c61effed7701a04"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc010ab034c8c7452522748bf937df58020d256ccae0874463d1f4d01758af8e"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c5dac79fa1614d1e06ca695109c6105923bd9c7d1d6c918d4e637b7e6b32fd3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fa868638bf362d3d138ea55829cefb3d5f4b0d7f142234382a15e2485dbec4"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:17299feefe090f2caa5b8e37222bb5f663e4935a8bfa6931d4102e5df1a9f398"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4c63ebc82684aa89d9a3bcbd13d515b3be44250dc68dd3bd81526c1cb31286c3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:aaa2a54443eff1950ba5ddc6b6ccda0d9c84a364276a62f969bdf2a390650848"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:18e5ceec2ab67e6d5f1a9085e5a24c9c4e2ac4545730bfe668680bca05e555f3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a0f62d0a58f4e7da165457e995725421e0064f2255d8eccebc49f41bbc23b109"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win32.whl", hash = "sha256:041bde0a48fd37cf71cab1c9d56d3e8625a3793fef1f7dd232b3ff37e978ecda"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win_amd64.whl", hash = "sha256:6f2eeda33a839975441c86a4119e1383c50b47faf0cbb5176985565c6bb02c33"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win_arm64.whl", hash = "sha256:14f4c5d6db102bd796a627bbb3a17b4cf4574b9ae861d8b7c9a9661c6dd3362d"}, - {file = "pydantic_core-2.46.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3245406455a5d98187ec35530fd772b1d799b26667980872c8d4614991e2c4a2"}, - {file = "pydantic_core-2.46.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:962ccbab7b642487b1d8b7df90ef677e03134cf1fd8880bf698649b22a69371f"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8233f2947cf85404441fd7e0085f53b10c93e0ee78611099b5c7237e36aacbf7"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a233125ac121aa3ffba9a2b59edfc4a985a76092dc8279586ab4b71390875e7"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b712b53160b79a5850310b912a5ef8e57e56947c8ad690c227f5c9d7e561712"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9401557acd873c3a7f3eb9383edef8ac4968f9510e340f4808d427e75667e7b4"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:926c9541b14b12b1681dca8a0b75feb510b06c6341b70a8e500c2fdcff837cce"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:56cb4851bcaf3d117eddcef4fe66afd750a50274b0da8e22be256d10e5611987"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c68fcd102d71ea85c5b2dfac3f4f8476eff42a9e078fd5faefff6d145063536b"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b2f69dec1725e79a012d920df1707de5caf7ed5e08f3be4435e25803efc47458"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:8d0820e8192167f80d88d64038e609c31452eeca865b4e1d9950a27a4609b00b"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fbdb89b3e1c94a30cc5edfce477c6e6a5dc4d8f84665b455c27582f211a1c72c"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win32.whl", hash = "sha256:9aa768456404a8bf48a4406685ac2bec8e72b62c69313734fa3b73cf33b3a894"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win_amd64.whl", hash = "sha256:e9c26f834c65f5752f3f06cb08cb86a913ceb7274d0db6e267808a708b46bc89"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win_arm64.whl", hash = "sha256:4fc73cb559bdb54b1134a706a2802a4cddd27a0633f5abb7e53056268751ac6a"}, - {file = "pydantic_core-2.46.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:5d5902252db0d3cedf8d4a1bc68f70eeb430f7e4c7104c8c476753519b423008"}, - {file = "pydantic_core-2.46.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94f0688e7b8d0a67abf40e57a7eaaecd17cc9586706a31b76c031f63df052b4"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f027324c56cd5406ca49c124b0db10e56c69064fec039acc571c29020cc87c76"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e739fee756ba1010f8bcccb534252e85a35fe45ae92c295a06059ce58b74ccd3"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d56801be94b86a9da183e5f3766e6310752b99ff647e38b09a9500d88e46e76"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2412e734dcb48da14d4e4006b82b46b74f2518b8a26ee7e58c6844a6cd6d03c4"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9551187363ffc0de2a00b2e47c25aeaeb1020b69b668762966df15fc5659dd5a"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0186750b482eefa11d7f435892b09c5c606193ef3375bcf94aa00ae6bfb66262"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5855698a4856556d86e8e6cd8434bc3ac0314ee8e12089ae0e143f64c6256e4e"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:cbaf13819775b7f769bf4a1f066cb6df7a28d4480081a589828ef190226881cd"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:633147d34cf4550417f12e2b1a0383973bdf5cdfde212cb09e9a581cf10820be"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:82cf5301172168103724d49a1444d3378cb20cdee30b116a1bd6031236298a5d"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win32.whl", hash = "sha256:9fa8ae11da9e2b3126c6426f147e0fba88d96d65921799bb30c6abd1cb2c97fb"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win_amd64.whl", hash = "sha256:6b3ace8194b0e5204818c92802dcdca7fc6d88aabbb799d7c795540d9cd6d292"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win_arm64.whl", hash = "sha256:184c081504d17f1c1066e430e117142b2c77d9448a97f7b65c6ac9fd9aee238d"}, - {file = "pydantic_core-2.46.4-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:428e04521a40150c85216fc8b85e8d39fece235a9cf5e383761238c7fa9b96fb"}, - {file = "pydantic_core-2.46.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:23ace664830ee0bfe014a0c7bc248b1f7f25ed7ad103852c317624a1083af462"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce5c1d2a8b27468f433ca974829c44060b8097eedc39933e3c206a90ee49c4a9"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7283d57845ecf5a163403eb0702dfc220cc4fbdd18919cb5ccea4f95ee1cdab4"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8daafc69c93ee8a0204506a3b6b30f586ef54028f52aeeeb5c4cfc5184fd5914"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2213145bcc2ba85884d0ac63d222fece9209678f77b9b4d76f054c561adb28"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a5f930472650a82629163023e630d160863fce524c616f4e5186e5de9d9a49b"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:c1b3f518abeca3aa13c712fd202306e145abf59a18b094a6bafb2d2bbf59192c"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a7dd0b3ee80d90150e3495a3a13ac34dbcbfd4f012996a6a1d8900e91b5c0fb"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:3fb702cd90b0446a3a1c5e470bfa0dd23c0233b676a9099ddcc964fa6ca13898"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:b8458003118a712e66286df6a707db01c52c0f52f7db8e4a38f0da1d3b94fc4e"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:372429a130e469c9cd698925ce5fc50940b7a1336b0d82038e63d5bbc4edc519"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win32.whl", hash = "sha256:85bb3611ff1802f3ee7fdd7dbff26b56f343fb432d57a4728fdd49b6ef35e2f4"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win_amd64.whl", hash = "sha256:811ff8e9c313ab425368bcbb36e5c4ebd7108c2bbf4e4089cfbb0b01eff63fac"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win_arm64.whl", hash = "sha256:bfec22eab3c8cc2ceec0248aec886624116dc079afa027ecc8ad4a7e62010f8a"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:af8244b2bef6aaad6d92cda81372de7f8c8d36c9f0c3ea36e827c60e7d9467a0"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a4330cdbc57162e4b3aa303f588ba752257694c9c9be3e7ebb11b4aca659b5d"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c61fc04a3d840155ff08e475a04809278972fe6aef51e2720554e96367e34b"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c50f2528cf200c5eed56faf3f4e22fcd5f38c157a8b78576e6ba3168ec35f000"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0cbe8b01f948de4286c74cdd6c667aceb38f5c1e26f0693b3983d9d74887c65e"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:617d7e2ca7dcb8c5cf6bcb8c59b8832c94b36196bbf1cbd1bfb56ed341905edd"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7027560ee92211647d0d34e3f7cd6f50da56399d26a9c8ad0da286d3869a53f3"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:f99626688942fb746e545232e7726926f3be91b5975f8b55327665fafda991c7"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fc3e9034a63de20e15e8ade85358bc6efc614008cab72898b4b4952bea0509ff"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:97e7cf2be5c77b7d1a9713a05605d49460d02c6078d38d8bef3cbe323c548424"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:3bf92c5d0e00fefaab325a4d27828fe6b6e2a21848686b5b60d2d9eeb09d76c6"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:3ecbc122d18468d06ca279dc26a8c2e2d5acb10943bb35e36ae92096dc3b5565"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win32.whl", hash = "sha256:e846ae7835bf0703ae43f534ab79a867146dadd59dc9ca5c8b53d5c8f7c9ef02"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win_amd64.whl", hash = "sha256:2108ba5c1c1eca18030634489dc544844144ee36357f2f9f780b93e7ddbb44b5"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win_arm64.whl", hash = "sha256:4fcbe087dbc2068af7eda3aa87634eba216dbda64d1ae73c8684b621d33f6596"}, - {file = "pydantic_core-2.46.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:fd8b3d9fd264be37976686c7f65cd52a83f5e84f4bfd2adf9c1d469676bbb6ae"}, - {file = "pydantic_core-2.46.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9f444c499b3eefd3a92e348059471ea0c3a6e303d9c1cec09fa748fd9f895201"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3447661d99f75a3683a4cf5c87da72f2161964611864dbbeac7fbb118bb4bfc0"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b9bab013d1c7a79d3501ff86d0bc9c31bf587db4551677b96bec07df78c6b15"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d995260fdf4e1db774581b4900e0f832abe3c7c84996726bbc161b19c8f29e76"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f13a646d65d09fbf1bc6b3a9635d30095c8e7e5cc419ff35ecc563c5fd04cd49"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432c179df7874eeb73307aad2df0755e1ae0efa61ff0ea89b93e194411ae3928"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_31_riscv64.whl", hash = "sha256:e68b7a074f65a2fd746c52a7ce6142ab7006074ac269ace0c25cd8ba171f8066"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4a05d69cba51d852c5c3e92758653245a50c0b646ced0cf05bd793ed592839d6"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:228ee9bae8bef5b1e97ec58302f80357c37199e0d0a99174e138d28e6957b9d9"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:10e17cbb10a330363733efc4d7c4d0dd827ac0909b8f6a6542298fed1ea62f29"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:91a06d2e259ecfbd8c901d70c3c507900458498142b3026a296b7de4d1322cc9"}, - {file = "pydantic_core-2.46.4-cp39-cp39-win32.whl", hash = "sha256:d80ee3d731373b24cebbc10d689ca4ee1875caf0d5703a245db18efd4dd37fc1"}, - {file = "pydantic_core-2.46.4-cp39-cp39-win_amd64.whl", hash = "sha256:3be77f45df024d789a672ae34f8b06fb346c4f9f46ea714956660ea4862e89ac"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:14d4edf427bdcf950a8a02d7cb44a08614388dd6e1bdcbf4f67504fa7887da9c"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:0ce40cd7b21210e99342afafbd4d0f76d784eb5b1d60f3bdc566be4983c6c73b"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90884113d8b48f760e9587002789ddd741e76ab9f89518cd1e43b1f1a52ec44b"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66ce7632c22d837c95301830e111ad0128a32b8207533b60896a96c4915192ea"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:1d8ba486450b14f3b1d63bc521d410ec7565e52f887b9fb671791886436a42f7"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3009f12e4e90b7f88b4f9adb1b0c4a3d58fe7820f3238c190047209d148026df"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad785e92e6dc634c21555edc8bd6b64957ab844541bcb96a1366c202951ae526"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00c603d540afdd6b80eb39f078f33ebd46211f02f33e34a32d9f053bba711de0"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:0c563b08bca408dc7f65f700633d8442fffb2421fc47b8101377e9fd65051ff0"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:db06ffe51636ffe9ca531fe9023dd64bdd794be8754cb5df57c5498ae5b518a7"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:133878133d271ade3d41d1bfb2a45ec38dbdbda40bc065921c6b04e4630127e2"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9bc519fbf2b7578398853d815009ae5e4d4603d12f4e3f91da8c06852d3da3e9"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c7a7bd4e39e8e4c12c39cd480356842b6a8a06e41b23a55a5e3e191718838ddf"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:d396ec2b979760aaf3218e76c24e65bd0aca24983298653b3a9d7a45f9e47b30"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:86e1a4418c6cd97d60c95c71164158eaf7324fae7b0923264016baa993eba6fc"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:d51026d73fcfd93610abc7b27789c26b313920fcfb20e27462d74a7f8b06e983"}, - {file = "pydantic_core-2.46.4.tar.gz", hash = "sha256:62f875393d7f270851f20523dd2e29f082bcc82292d66db2b64ea71f64b6e1c1"}, + {file = "pydantic_core-2.47.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d4c7148fc6c0bb727139010e15aab198be6c5d00276f83246b417ce69831f9d6"}, + {file = "pydantic_core-2.47.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:497f91a0499fa4ce7ae982756f8a237af19f145d944258c0c991cfb78aee13b0"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6cd3cde901878cce06787608a50c4456b8ad49c2128440c24b96b5624d26937"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b02b17e44bcb066b9f3a1d31c0be01a59f81d0b94b5066fdded1a8ccc8f819e"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e42feab6c93fa3264708502f5062147073c7e57bc56bc1c44ed00efa53bc1859"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a3c450c65dfd14eee570756f61c823f8a7a36a3f4f4d46a3945d6225dc8a47d8"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0744954b81a77cfd381760d8b1bf92ba4db57d2da235e695d6fd3c94f741d24d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:34336a9cdd8b54e0cbc9d3c36b7be7dd828fa10e4209a7e4b0ca4583aa3e696d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8623f4e598c9cee799075d0f1ab5174f9f2e7c42b3c5c7d859a97b3c726c84f8"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6143463422e3851187657ff677af0cb04c5e1a2c51c028438ce5f20fbb5cb50d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:152bcbfe9f087716d185f6003be549f2cc6ee3cd4ca67909118e31626afc209b"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad7522d6e0cf28192b30f4c9db62e9b7a13ef10652bf8310de27bcb4a6c1c40"}, + {file = "pydantic_core-2.47.0-cp310-cp310-win32.whl", hash = "sha256:e1c8ea447dcfaa7f7d815d07bddb131383275682601878b5711f59fac68045a2"}, + {file = "pydantic_core-2.47.0-cp310-cp310-win_amd64.whl", hash = "sha256:df82086e6efb002a8e4f8f787dd2ddf9db46403fe8697b7620111663799b62b8"}, + {file = "pydantic_core-2.47.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fe87ccbc39a103709d0a5afa75240c15a94611af129261e9484bef0bc97960b2"}, + {file = "pydantic_core-2.47.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45882c24324f123037982c65eb8d60da778447e6bc87c82241f81d6c6d2c307e"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:983e39b39547772543f3518557d0a86dbb3b7bb58bf8e82faba1e0cfa3e816d9"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3eb92447d3a079b945b61b8cbd6c3ec2954de3655c4efa0ebd35b069e472c2a9"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af9600c680bec4b8d23c32ddaf7a5d91ed39a2cf758c082e34e860140cdcd87"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d6098342d4510a9034a500a53b1d737daf9cfc18a47cd21047d02d7d1587557"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a891c20be5110deb1904f639f3615ec5022b3495995850d1abe7b8fa1550b5"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:13990d357a50078e382b15fa3ce3f08043223b4be3eaeb340b184f54c1a2397a"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a0f7343305387bb5884f24d384b7978ad099a277b27529e592c041a502a37c32"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2c5e11bb1be2de2707c9367f364e73091ef30d34be54b3a4564d7421ac1a16bd"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:e1b52aa981e034896712460c899ee30707c8c6a385e79bb7648aac76c748a3da"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d57d46021c20d4efc28f69b4ca4670dffbb7bdafa51d1967b747849726ec643b"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win32.whl", hash = "sha256:d93c02ae8bd33f73624319d85cba47e754155c5bf104c0c5ca96fcd1f3094939"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win_amd64.whl", hash = "sha256:859ca679f00e5feb11b58b616eb7bc0efbb13654be21f5c898e510e27671c900"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win_arm64.whl", hash = "sha256:482667e5b7a3e97b0836f33a716199c4ec6ba9c896ca4db6eae799ec527c1e64"}, + {file = "pydantic_core-2.47.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:70a7aeba54854f5d97da65cb1a61f000f53df3704cab41cd81d65ab127ddd031"}, + {file = "pydantic_core-2.47.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c0e97329e38228f57fe1f2d91ba0ef39cc75cc1a84fe6ef58942d2fc6cb406bf"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:faf83e50714837af72f13e9369c50377552a4a74049d4477bed51c7e5822d94b"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f50abe60347bf8afe2b2f58db86cf3ac6e418eec7ffa01d9dc90ba29fc64f243"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13b18e7dec056336f29ee77dea3cc5db0271d6215cac7249cc5c61b0a49d293"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3aaabbdbbbb8dff33fa053ffb2c980f39dec745fc03592f50e1e010449129841"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d2907ee6d15cf26787bfbeb4c42e18e52f358086eab91baea961a0d909248d6"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:c413761260967f4dbb51135e1b49f30a1c29e15bf371fbb39754ed6475739545"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2bb527ac6e5a9721023b24615e1a55c01f47c60007f08b7d2afb89ff9c7a0e22"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b3abf7fd5e6abe483a63413f9cd26b7c93c20780e19c8556434c7279f6b2f10c"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:9cefe43d17a5a273d71697c084d3787defa7f578cd5fab4cef4c66d13c9e44b2"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:38a8b5371b7938e4f6c31060dbd51d3b3229aef7c43eaac2eb8b153e43c3f189"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win32.whl", hash = "sha256:b87e95e644df2a36bd631dd0d6e097aa73d19a55adf7b1724ebdeece3d9c76b7"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0078c5695322050ccedbc86eafaf3e2548439782c51d99e575de0e31b9fe4f4"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win_arm64.whl", hash = "sha256:7eedc31996e9eba3bdfbbc380805ac6d765c889b7e93b17cd00ecb0200fd6dca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:263560ece98bffbbc0a8047ce60b8a278c859db6a2a4e30d9454b02891045eca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a7fdaff39a66bf66e9037da482575513d2f20bfb02ea9d9222b5cb3b902fc695"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3e6c8ee5ce8c270bfae09763ae4bbbccfe81090c97d670a621fb86cb1ef6042"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e38cdae682cfec4b3816722dccf6376ca59049726d57dca83c2fe7cc13665589"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc3193ff0b7e2e168f84c6185e70475738c191f3154e0af8f897cd0f8f9a489"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57ff41672a615f38af528ee904602be51c653248354e5db8e9252668abe91e68"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473b9a2b2a1f0dd55cbb32d2b902f93babe7f141a0bb48fb4d3d4d2b3e93e9a0"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:195f9c4ac43a7b2a044a7b86631c3352abdb820bed2823ea29f98f779255f459"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c1cd10d39ef1ff8bcd68b6865bee9c434631ac0608d402fe86e678851c2e2a5"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7cbe66352fe2b39511d49150e5b52159429cd21f5633a3e801dd2c43829dcdca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:e35192a1d53e55d510d8bb1023c988c7cdae6d94539074971741b2a7656e49a1"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:81de54576de2e20baec76cc5afae2820f9049e6fdc4f357bac3391da02d0ba97"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win32.whl", hash = "sha256:05da6647bdfd3888936ac10aa39b239d659f3c93dff281af0fc5943eb55629dd"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win_amd64.whl", hash = "sha256:021220e0a03b66112737ee1fc49759340ce8fafb8d9ade1b7fb366b06033fa45"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win_arm64.whl", hash = "sha256:55156ee2f6f561ea4e25ab55f84bd70b9c9ed2546a834cb2b038fe10225aaa37"}, + {file = "pydantic_core-2.47.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:6e37a6974fbd8fa7cae12285a76970d50b3689ffd6ed7c7fdd176ba81dd22d0e"}, + {file = "pydantic_core-2.47.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2433b8524785cc117e602233bc574879bc8d87f09523edeec51665d5c46cf42d"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8f9f2be064c8bf1189f46f7062fd42765d94f59cfb7db7ef8db19563192110a"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1f0a9659a2eb161573418e3138f616101ba21bbd2ff04916dca7b6712155e015"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2995074b99242aa28991e0120a3c881babc139e08750a05b7ea7d140644e091d"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5b224dc04c3ff9b08c24419464eb7f6ad7a1049e12284a00bf80df82bd15fdb"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:264361b7236d4374fef6342908f87d084a0d58a2f8d0811e99f714309cb0ba7e"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:53368beaf693f6302a6e33bdefe950857534a04d282811421bd20176d0fb5636"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5d2177b44ba7d9d86850f865f362feeaac6a2ed8517a9b505b97ff0b7fdbd7dd"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:fb57a6b538ec7a01b937986bc093aec530fb056135b6bc9cfdd0bf8460c25bc2"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:18c9c7c3a18e9bdbf1215d913f6bd00e17595dc92949817935cb87a3cf5f1697"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:2f72a382886ca85bb1247303b9134cf9978c9d454de62e710a1ebcd9d2131927"}, + {file = "pydantic_core-2.47.0-cp314-cp314-pyemscripten_2026_0_wasm32.whl", hash = "sha256:cdf4dc2cdd0eacad1bd81c4d25422b4c25b206acae095d2d64e5d5cb7facc6b3"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win32.whl", hash = "sha256:1e859dd5e06e9807080e14995db131649a77c61131cc464a7fe492a69ce82488"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win_amd64.whl", hash = "sha256:234ecade0e358caa1ea516c218b3f61e61e30532cad1a8bb12f2487325838548"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win_arm64.whl", hash = "sha256:58158d0111e86893bc35aacabe509f951ed303cddf8cdba43533190bde317914"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:24017fd3befd4d7cc4a8c71f4a1e9a44d29fdc91723c5446b0e795ab808adee7"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:dfa0820888cc4549fcce7e6bd8affa7d75198d885ccd0bb0760def4bb8461862"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1d507f331756f7066cde7c9f35ed78fa78223a54369dbc8a34d6da7a5074fa1"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c4a885fb50c05903bc703a00830616e680304fdcdd90fc9535a52e72debe712"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b107cbd764bf68f12a57c7aa5846d868bc7463490a1ac2d0f19bffef624c5a9"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b5f46412c8226d0c8f1f423c324c75afe342e4b854836933579fb484f68598c"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90cbf7e35d597503cbdb5cd85409cbb75f377290bc7e8e37cc5dfe4f5cc66cf8"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:04ee7ba7172cb4484af51b2890f19069d35773698abc8c6ebb651f52fbf41134"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a7acb3e120ddba94372b8146e62ea3a0bce203180e34641c817c87f995c91e0"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:1a2f7ceb58013d167d8c96f10ec9b3a137018c819ba356f68ff1cb74302fd22e"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:482b097637fec5037eb13fe9f9d5fe47e568a5b451d686bc2b854076cc0b50ca"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a8c7f5fad73eb404f4b84c75f3d9d3865b748ded248b7366341db6e516fc502b"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win32.whl", hash = "sha256:f343c39928097175acf2f7d0cba5c00b0f62265d88a173a4ce264266ac849bd9"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win_amd64.whl", hash = "sha256:52d40e074da44e42b2425aedebd513e405a31807036ef597175117a9b01743a6"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win_arm64.whl", hash = "sha256:25cac08c9735e61e5c0b9f7a85c438661fc0e9da226afdfa984c3da6c5942a5a"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:dbe50ffd50c7e1b8e3bedafafd10fabe8a7cf51916f995fd57316eb1293f2439"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:67c3f4b57a1cda846a9b6258e0ef70f99b0757d0b6a55f7e3e5b100620937d2c"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c1e025c83b3987784191ec5aa78cfb94686f1ba73d98ae4d520c09cdecda8d"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547cac87585ab8a6e9495fd3be22ae8ccaabb3c909f6ae589a180677c105cc8"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:d648515c65f8871ceca6a0446be32fb1a0aae414db3907ec0df6cc2380dd0c04"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3292813ac8ed17671a5b1ac8256b6e98b96c29c1c6d061c35899e2f6e0444cac"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf3cfcea028f41a9b42a47f4a53d72ca4274d04e3ee525bd58ca89ea8c5e1910"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52afdd8640d59890539060e91c8a494cf96b463e63b2ba499cc5c23abcb5e96b"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ff201e8b93cd6a3849ec3fecb1d642d0391ca4a387f1162adaadddfd1986598c"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:eba39fca6e1fc7c58a5e6b5c320c6eec86014f7075dfa0fb1a79076e00304b3a"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b494566e8e7e1a3992a550900902d6c0051538e53f5957e24003e7775638a24"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c96f745745fe778a92e74f40675c57b80dd46fc98186c9020e5ba4514a0470bd"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dbfb8f8ac48a8fd5a9f83cdbba0a23796f3ff5f9421c735808ed134d3318046f"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:4ded2d8dc3beb086623ab3256b81a0b51fc026c94eb363f480aec09b204a1cbd"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:25f8aa59fd3f85651387c37b289cd8f0bbc8f632e7dad13c6b837759e78be88d"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:0ae0096729c64aa6155cef74e75d1945017b01cfb82f313a18d843fafc497476"}, + {file = "pydantic_core-2.47.0.tar.gz", hash = "sha256:422c1797a7864b2a9a996435aba92fe571fb80190f67a31edbc1ac040c7b51fe"}, ] [package.dependencies] @@ -1149,23 +1136,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pydantic-v1-with-utils/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pydantic-v1-with-utils/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pydantic-v1-with-utils/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pydantic-v1-with-utils/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pydantic-v1-wrapped/poetry.lock b/seed/python-sdk/exhaustive/pydantic-v1-wrapped/poetry.lock index 63fa2ad6c2c8..f2eeffef483a 100644 --- a/seed/python-sdk/exhaustive/pydantic-v1-wrapped/poetry.lock +++ b/seed/python-sdk/exhaustive/pydantic-v1-wrapped/poetry.lock @@ -977,132 +977,119 @@ email = ["email-validator (>=1.0.3)"] [[package]] name = "pydantic-core" -version = "2.46.4" +version = "2.47.0" description = "Core functionality for Pydantic validation and serialization" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "pydantic_core-2.46.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a396dcc17e5a0b164dbe026896245a4fa9ff402edca1dff0be3d53a517f74de4"}, - {file = "pydantic_core-2.46.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:da4b951fe36dc7c3a1ccb4e3cd1747c3542b8c9ceede8fc86cae054e764485f5"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb63e0198ca18aad131c089b9204c23079c3afa95487e561f4c522d519e55aba"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f47286a97f0bc9b8859519809077b91b2cefe4ae47fcbf5e466a009c1c5d742b"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a0ed8ea6f2d61c1738835f99b699348d7857379083e5fc497fa0c967a407c"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea793e075b70290d89d8142074262885d3f7da19634845135751bd6344f73b50"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395aebd9183f9d112f569aeb5b2214d1a10a33bec8456447f7fbdfa51d38d4cd"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:b078afbc25f3a1436c7a1d2cd3e322497ee99615ba97c563566fdf46aff1ee01"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f747929cf940cddb5b3668a390056ddd5ba2e5010615ea2dcf4f9c4f3ab8791d"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:daa27d92c36f24388fe3ad306b174781c747627f134452e4f128ea00ce1fe8c4"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:19e51f073cd3df251856a8a4189fbdf1de4012c3ebacfb1884f94f1eb406079f"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1747f85cee84c26985853c6f3d9bd3e75da5212912443fa111c113b9c246f39"}, - {file = "pydantic_core-2.46.4-cp310-cp310-win32.whl", hash = "sha256:2f84c03c8607173d16b5a854ec68a2f9079ae03237a54fb506d13af47e1d018d"}, - {file = "pydantic_core-2.46.4-cp310-cp310-win_amd64.whl", hash = "sha256:8358a950c8909158e3df31538a7e4edc2d7265a7c54b47f0864d9e5bae9dcebf"}, - {file = "pydantic_core-2.46.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:0e96592440881c74a213e5ad528e2b24d3d4f940de2766bed9010ab1d9e51594"}, - {file = "pydantic_core-2.46.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0d65b8c354be7fb5f720c3caa8bc940bc2d20ce749c8e06135f07f8ed95dd7c"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bfb192b3f4b9e8a89b6277b6ce787564f62cfd272055f6e685726b111dc7826"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9037063db01f09b09e237c282b6792bd4da634b5402c4e7f0c61effed7701a04"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc010ab034c8c7452522748bf937df58020d256ccae0874463d1f4d01758af8e"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c5dac79fa1614d1e06ca695109c6105923bd9c7d1d6c918d4e637b7e6b32fd3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fa868638bf362d3d138ea55829cefb3d5f4b0d7f142234382a15e2485dbec4"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:17299feefe090f2caa5b8e37222bb5f663e4935a8bfa6931d4102e5df1a9f398"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4c63ebc82684aa89d9a3bcbd13d515b3be44250dc68dd3bd81526c1cb31286c3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:aaa2a54443eff1950ba5ddc6b6ccda0d9c84a364276a62f969bdf2a390650848"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:18e5ceec2ab67e6d5f1a9085e5a24c9c4e2ac4545730bfe668680bca05e555f3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a0f62d0a58f4e7da165457e995725421e0064f2255d8eccebc49f41bbc23b109"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win32.whl", hash = "sha256:041bde0a48fd37cf71cab1c9d56d3e8625a3793fef1f7dd232b3ff37e978ecda"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win_amd64.whl", hash = "sha256:6f2eeda33a839975441c86a4119e1383c50b47faf0cbb5176985565c6bb02c33"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win_arm64.whl", hash = "sha256:14f4c5d6db102bd796a627bbb3a17b4cf4574b9ae861d8b7c9a9661c6dd3362d"}, - {file = "pydantic_core-2.46.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3245406455a5d98187ec35530fd772b1d799b26667980872c8d4614991e2c4a2"}, - {file = "pydantic_core-2.46.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:962ccbab7b642487b1d8b7df90ef677e03134cf1fd8880bf698649b22a69371f"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8233f2947cf85404441fd7e0085f53b10c93e0ee78611099b5c7237e36aacbf7"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a233125ac121aa3ffba9a2b59edfc4a985a76092dc8279586ab4b71390875e7"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b712b53160b79a5850310b912a5ef8e57e56947c8ad690c227f5c9d7e561712"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9401557acd873c3a7f3eb9383edef8ac4968f9510e340f4808d427e75667e7b4"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:926c9541b14b12b1681dca8a0b75feb510b06c6341b70a8e500c2fdcff837cce"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:56cb4851bcaf3d117eddcef4fe66afd750a50274b0da8e22be256d10e5611987"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c68fcd102d71ea85c5b2dfac3f4f8476eff42a9e078fd5faefff6d145063536b"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b2f69dec1725e79a012d920df1707de5caf7ed5e08f3be4435e25803efc47458"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:8d0820e8192167f80d88d64038e609c31452eeca865b4e1d9950a27a4609b00b"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fbdb89b3e1c94a30cc5edfce477c6e6a5dc4d8f84665b455c27582f211a1c72c"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win32.whl", hash = "sha256:9aa768456404a8bf48a4406685ac2bec8e72b62c69313734fa3b73cf33b3a894"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win_amd64.whl", hash = "sha256:e9c26f834c65f5752f3f06cb08cb86a913ceb7274d0db6e267808a708b46bc89"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win_arm64.whl", hash = "sha256:4fc73cb559bdb54b1134a706a2802a4cddd27a0633f5abb7e53056268751ac6a"}, - {file = "pydantic_core-2.46.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:5d5902252db0d3cedf8d4a1bc68f70eeb430f7e4c7104c8c476753519b423008"}, - {file = "pydantic_core-2.46.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94f0688e7b8d0a67abf40e57a7eaaecd17cc9586706a31b76c031f63df052b4"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f027324c56cd5406ca49c124b0db10e56c69064fec039acc571c29020cc87c76"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e739fee756ba1010f8bcccb534252e85a35fe45ae92c295a06059ce58b74ccd3"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d56801be94b86a9da183e5f3766e6310752b99ff647e38b09a9500d88e46e76"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2412e734dcb48da14d4e4006b82b46b74f2518b8a26ee7e58c6844a6cd6d03c4"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9551187363ffc0de2a00b2e47c25aeaeb1020b69b668762966df15fc5659dd5a"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0186750b482eefa11d7f435892b09c5c606193ef3375bcf94aa00ae6bfb66262"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5855698a4856556d86e8e6cd8434bc3ac0314ee8e12089ae0e143f64c6256e4e"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:cbaf13819775b7f769bf4a1f066cb6df7a28d4480081a589828ef190226881cd"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:633147d34cf4550417f12e2b1a0383973bdf5cdfde212cb09e9a581cf10820be"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:82cf5301172168103724d49a1444d3378cb20cdee30b116a1bd6031236298a5d"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win32.whl", hash = "sha256:9fa8ae11da9e2b3126c6426f147e0fba88d96d65921799bb30c6abd1cb2c97fb"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win_amd64.whl", hash = "sha256:6b3ace8194b0e5204818c92802dcdca7fc6d88aabbb799d7c795540d9cd6d292"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win_arm64.whl", hash = "sha256:184c081504d17f1c1066e430e117142b2c77d9448a97f7b65c6ac9fd9aee238d"}, - {file = "pydantic_core-2.46.4-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:428e04521a40150c85216fc8b85e8d39fece235a9cf5e383761238c7fa9b96fb"}, - {file = "pydantic_core-2.46.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:23ace664830ee0bfe014a0c7bc248b1f7f25ed7ad103852c317624a1083af462"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce5c1d2a8b27468f433ca974829c44060b8097eedc39933e3c206a90ee49c4a9"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7283d57845ecf5a163403eb0702dfc220cc4fbdd18919cb5ccea4f95ee1cdab4"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8daafc69c93ee8a0204506a3b6b30f586ef54028f52aeeeb5c4cfc5184fd5914"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2213145bcc2ba85884d0ac63d222fece9209678f77b9b4d76f054c561adb28"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a5f930472650a82629163023e630d160863fce524c616f4e5186e5de9d9a49b"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:c1b3f518abeca3aa13c712fd202306e145abf59a18b094a6bafb2d2bbf59192c"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a7dd0b3ee80d90150e3495a3a13ac34dbcbfd4f012996a6a1d8900e91b5c0fb"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:3fb702cd90b0446a3a1c5e470bfa0dd23c0233b676a9099ddcc964fa6ca13898"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:b8458003118a712e66286df6a707db01c52c0f52f7db8e4a38f0da1d3b94fc4e"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:372429a130e469c9cd698925ce5fc50940b7a1336b0d82038e63d5bbc4edc519"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win32.whl", hash = "sha256:85bb3611ff1802f3ee7fdd7dbff26b56f343fb432d57a4728fdd49b6ef35e2f4"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win_amd64.whl", hash = "sha256:811ff8e9c313ab425368bcbb36e5c4ebd7108c2bbf4e4089cfbb0b01eff63fac"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win_arm64.whl", hash = "sha256:bfec22eab3c8cc2ceec0248aec886624116dc079afa027ecc8ad4a7e62010f8a"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:af8244b2bef6aaad6d92cda81372de7f8c8d36c9f0c3ea36e827c60e7d9467a0"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a4330cdbc57162e4b3aa303f588ba752257694c9c9be3e7ebb11b4aca659b5d"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c61fc04a3d840155ff08e475a04809278972fe6aef51e2720554e96367e34b"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c50f2528cf200c5eed56faf3f4e22fcd5f38c157a8b78576e6ba3168ec35f000"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0cbe8b01f948de4286c74cdd6c667aceb38f5c1e26f0693b3983d9d74887c65e"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:617d7e2ca7dcb8c5cf6bcb8c59b8832c94b36196bbf1cbd1bfb56ed341905edd"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7027560ee92211647d0d34e3f7cd6f50da56399d26a9c8ad0da286d3869a53f3"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:f99626688942fb746e545232e7726926f3be91b5975f8b55327665fafda991c7"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fc3e9034a63de20e15e8ade85358bc6efc614008cab72898b4b4952bea0509ff"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:97e7cf2be5c77b7d1a9713a05605d49460d02c6078d38d8bef3cbe323c548424"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:3bf92c5d0e00fefaab325a4d27828fe6b6e2a21848686b5b60d2d9eeb09d76c6"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:3ecbc122d18468d06ca279dc26a8c2e2d5acb10943bb35e36ae92096dc3b5565"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win32.whl", hash = "sha256:e846ae7835bf0703ae43f534ab79a867146dadd59dc9ca5c8b53d5c8f7c9ef02"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win_amd64.whl", hash = "sha256:2108ba5c1c1eca18030634489dc544844144ee36357f2f9f780b93e7ddbb44b5"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win_arm64.whl", hash = "sha256:4fcbe087dbc2068af7eda3aa87634eba216dbda64d1ae73c8684b621d33f6596"}, - {file = "pydantic_core-2.46.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:fd8b3d9fd264be37976686c7f65cd52a83f5e84f4bfd2adf9c1d469676bbb6ae"}, - {file = "pydantic_core-2.46.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9f444c499b3eefd3a92e348059471ea0c3a6e303d9c1cec09fa748fd9f895201"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3447661d99f75a3683a4cf5c87da72f2161964611864dbbeac7fbb118bb4bfc0"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b9bab013d1c7a79d3501ff86d0bc9c31bf587db4551677b96bec07df78c6b15"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d995260fdf4e1db774581b4900e0f832abe3c7c84996726bbc161b19c8f29e76"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f13a646d65d09fbf1bc6b3a9635d30095c8e7e5cc419ff35ecc563c5fd04cd49"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432c179df7874eeb73307aad2df0755e1ae0efa61ff0ea89b93e194411ae3928"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_31_riscv64.whl", hash = "sha256:e68b7a074f65a2fd746c52a7ce6142ab7006074ac269ace0c25cd8ba171f8066"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4a05d69cba51d852c5c3e92758653245a50c0b646ced0cf05bd793ed592839d6"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:228ee9bae8bef5b1e97ec58302f80357c37199e0d0a99174e138d28e6957b9d9"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:10e17cbb10a330363733efc4d7c4d0dd827ac0909b8f6a6542298fed1ea62f29"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:91a06d2e259ecfbd8c901d70c3c507900458498142b3026a296b7de4d1322cc9"}, - {file = "pydantic_core-2.46.4-cp39-cp39-win32.whl", hash = "sha256:d80ee3d731373b24cebbc10d689ca4ee1875caf0d5703a245db18efd4dd37fc1"}, - {file = "pydantic_core-2.46.4-cp39-cp39-win_amd64.whl", hash = "sha256:3be77f45df024d789a672ae34f8b06fb346c4f9f46ea714956660ea4862e89ac"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:14d4edf427bdcf950a8a02d7cb44a08614388dd6e1bdcbf4f67504fa7887da9c"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:0ce40cd7b21210e99342afafbd4d0f76d784eb5b1d60f3bdc566be4983c6c73b"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90884113d8b48f760e9587002789ddd741e76ab9f89518cd1e43b1f1a52ec44b"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66ce7632c22d837c95301830e111ad0128a32b8207533b60896a96c4915192ea"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:1d8ba486450b14f3b1d63bc521d410ec7565e52f887b9fb671791886436a42f7"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3009f12e4e90b7f88b4f9adb1b0c4a3d58fe7820f3238c190047209d148026df"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad785e92e6dc634c21555edc8bd6b64957ab844541bcb96a1366c202951ae526"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00c603d540afdd6b80eb39f078f33ebd46211f02f33e34a32d9f053bba711de0"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:0c563b08bca408dc7f65f700633d8442fffb2421fc47b8101377e9fd65051ff0"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:db06ffe51636ffe9ca531fe9023dd64bdd794be8754cb5df57c5498ae5b518a7"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:133878133d271ade3d41d1bfb2a45ec38dbdbda40bc065921c6b04e4630127e2"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9bc519fbf2b7578398853d815009ae5e4d4603d12f4e3f91da8c06852d3da3e9"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c7a7bd4e39e8e4c12c39cd480356842b6a8a06e41b23a55a5e3e191718838ddf"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:d396ec2b979760aaf3218e76c24e65bd0aca24983298653b3a9d7a45f9e47b30"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:86e1a4418c6cd97d60c95c71164158eaf7324fae7b0923264016baa993eba6fc"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:d51026d73fcfd93610abc7b27789c26b313920fcfb20e27462d74a7f8b06e983"}, - {file = "pydantic_core-2.46.4.tar.gz", hash = "sha256:62f875393d7f270851f20523dd2e29f082bcc82292d66db2b64ea71f64b6e1c1"}, + {file = "pydantic_core-2.47.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d4c7148fc6c0bb727139010e15aab198be6c5d00276f83246b417ce69831f9d6"}, + {file = "pydantic_core-2.47.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:497f91a0499fa4ce7ae982756f8a237af19f145d944258c0c991cfb78aee13b0"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6cd3cde901878cce06787608a50c4456b8ad49c2128440c24b96b5624d26937"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b02b17e44bcb066b9f3a1d31c0be01a59f81d0b94b5066fdded1a8ccc8f819e"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e42feab6c93fa3264708502f5062147073c7e57bc56bc1c44ed00efa53bc1859"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a3c450c65dfd14eee570756f61c823f8a7a36a3f4f4d46a3945d6225dc8a47d8"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0744954b81a77cfd381760d8b1bf92ba4db57d2da235e695d6fd3c94f741d24d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:34336a9cdd8b54e0cbc9d3c36b7be7dd828fa10e4209a7e4b0ca4583aa3e696d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8623f4e598c9cee799075d0f1ab5174f9f2e7c42b3c5c7d859a97b3c726c84f8"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6143463422e3851187657ff677af0cb04c5e1a2c51c028438ce5f20fbb5cb50d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:152bcbfe9f087716d185f6003be549f2cc6ee3cd4ca67909118e31626afc209b"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad7522d6e0cf28192b30f4c9db62e9b7a13ef10652bf8310de27bcb4a6c1c40"}, + {file = "pydantic_core-2.47.0-cp310-cp310-win32.whl", hash = "sha256:e1c8ea447dcfaa7f7d815d07bddb131383275682601878b5711f59fac68045a2"}, + {file = "pydantic_core-2.47.0-cp310-cp310-win_amd64.whl", hash = "sha256:df82086e6efb002a8e4f8f787dd2ddf9db46403fe8697b7620111663799b62b8"}, + {file = "pydantic_core-2.47.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fe87ccbc39a103709d0a5afa75240c15a94611af129261e9484bef0bc97960b2"}, + {file = "pydantic_core-2.47.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45882c24324f123037982c65eb8d60da778447e6bc87c82241f81d6c6d2c307e"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:983e39b39547772543f3518557d0a86dbb3b7bb58bf8e82faba1e0cfa3e816d9"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3eb92447d3a079b945b61b8cbd6c3ec2954de3655c4efa0ebd35b069e472c2a9"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af9600c680bec4b8d23c32ddaf7a5d91ed39a2cf758c082e34e860140cdcd87"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d6098342d4510a9034a500a53b1d737daf9cfc18a47cd21047d02d7d1587557"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a891c20be5110deb1904f639f3615ec5022b3495995850d1abe7b8fa1550b5"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:13990d357a50078e382b15fa3ce3f08043223b4be3eaeb340b184f54c1a2397a"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a0f7343305387bb5884f24d384b7978ad099a277b27529e592c041a502a37c32"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2c5e11bb1be2de2707c9367f364e73091ef30d34be54b3a4564d7421ac1a16bd"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:e1b52aa981e034896712460c899ee30707c8c6a385e79bb7648aac76c748a3da"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d57d46021c20d4efc28f69b4ca4670dffbb7bdafa51d1967b747849726ec643b"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win32.whl", hash = "sha256:d93c02ae8bd33f73624319d85cba47e754155c5bf104c0c5ca96fcd1f3094939"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win_amd64.whl", hash = "sha256:859ca679f00e5feb11b58b616eb7bc0efbb13654be21f5c898e510e27671c900"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win_arm64.whl", hash = "sha256:482667e5b7a3e97b0836f33a716199c4ec6ba9c896ca4db6eae799ec527c1e64"}, + {file = "pydantic_core-2.47.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:70a7aeba54854f5d97da65cb1a61f000f53df3704cab41cd81d65ab127ddd031"}, + {file = "pydantic_core-2.47.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c0e97329e38228f57fe1f2d91ba0ef39cc75cc1a84fe6ef58942d2fc6cb406bf"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:faf83e50714837af72f13e9369c50377552a4a74049d4477bed51c7e5822d94b"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f50abe60347bf8afe2b2f58db86cf3ac6e418eec7ffa01d9dc90ba29fc64f243"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13b18e7dec056336f29ee77dea3cc5db0271d6215cac7249cc5c61b0a49d293"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3aaabbdbbbb8dff33fa053ffb2c980f39dec745fc03592f50e1e010449129841"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d2907ee6d15cf26787bfbeb4c42e18e52f358086eab91baea961a0d909248d6"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:c413761260967f4dbb51135e1b49f30a1c29e15bf371fbb39754ed6475739545"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2bb527ac6e5a9721023b24615e1a55c01f47c60007f08b7d2afb89ff9c7a0e22"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b3abf7fd5e6abe483a63413f9cd26b7c93c20780e19c8556434c7279f6b2f10c"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:9cefe43d17a5a273d71697c084d3787defa7f578cd5fab4cef4c66d13c9e44b2"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:38a8b5371b7938e4f6c31060dbd51d3b3229aef7c43eaac2eb8b153e43c3f189"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win32.whl", hash = "sha256:b87e95e644df2a36bd631dd0d6e097aa73d19a55adf7b1724ebdeece3d9c76b7"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0078c5695322050ccedbc86eafaf3e2548439782c51d99e575de0e31b9fe4f4"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win_arm64.whl", hash = "sha256:7eedc31996e9eba3bdfbbc380805ac6d765c889b7e93b17cd00ecb0200fd6dca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:263560ece98bffbbc0a8047ce60b8a278c859db6a2a4e30d9454b02891045eca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a7fdaff39a66bf66e9037da482575513d2f20bfb02ea9d9222b5cb3b902fc695"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3e6c8ee5ce8c270bfae09763ae4bbbccfe81090c97d670a621fb86cb1ef6042"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e38cdae682cfec4b3816722dccf6376ca59049726d57dca83c2fe7cc13665589"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc3193ff0b7e2e168f84c6185e70475738c191f3154e0af8f897cd0f8f9a489"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57ff41672a615f38af528ee904602be51c653248354e5db8e9252668abe91e68"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473b9a2b2a1f0dd55cbb32d2b902f93babe7f141a0bb48fb4d3d4d2b3e93e9a0"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:195f9c4ac43a7b2a044a7b86631c3352abdb820bed2823ea29f98f779255f459"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c1cd10d39ef1ff8bcd68b6865bee9c434631ac0608d402fe86e678851c2e2a5"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7cbe66352fe2b39511d49150e5b52159429cd21f5633a3e801dd2c43829dcdca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:e35192a1d53e55d510d8bb1023c988c7cdae6d94539074971741b2a7656e49a1"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:81de54576de2e20baec76cc5afae2820f9049e6fdc4f357bac3391da02d0ba97"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win32.whl", hash = "sha256:05da6647bdfd3888936ac10aa39b239d659f3c93dff281af0fc5943eb55629dd"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win_amd64.whl", hash = "sha256:021220e0a03b66112737ee1fc49759340ce8fafb8d9ade1b7fb366b06033fa45"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win_arm64.whl", hash = "sha256:55156ee2f6f561ea4e25ab55f84bd70b9c9ed2546a834cb2b038fe10225aaa37"}, + {file = "pydantic_core-2.47.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:6e37a6974fbd8fa7cae12285a76970d50b3689ffd6ed7c7fdd176ba81dd22d0e"}, + {file = "pydantic_core-2.47.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2433b8524785cc117e602233bc574879bc8d87f09523edeec51665d5c46cf42d"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8f9f2be064c8bf1189f46f7062fd42765d94f59cfb7db7ef8db19563192110a"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1f0a9659a2eb161573418e3138f616101ba21bbd2ff04916dca7b6712155e015"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2995074b99242aa28991e0120a3c881babc139e08750a05b7ea7d140644e091d"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5b224dc04c3ff9b08c24419464eb7f6ad7a1049e12284a00bf80df82bd15fdb"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:264361b7236d4374fef6342908f87d084a0d58a2f8d0811e99f714309cb0ba7e"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:53368beaf693f6302a6e33bdefe950857534a04d282811421bd20176d0fb5636"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5d2177b44ba7d9d86850f865f362feeaac6a2ed8517a9b505b97ff0b7fdbd7dd"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:fb57a6b538ec7a01b937986bc093aec530fb056135b6bc9cfdd0bf8460c25bc2"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:18c9c7c3a18e9bdbf1215d913f6bd00e17595dc92949817935cb87a3cf5f1697"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:2f72a382886ca85bb1247303b9134cf9978c9d454de62e710a1ebcd9d2131927"}, + {file = "pydantic_core-2.47.0-cp314-cp314-pyemscripten_2026_0_wasm32.whl", hash = "sha256:cdf4dc2cdd0eacad1bd81c4d25422b4c25b206acae095d2d64e5d5cb7facc6b3"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win32.whl", hash = "sha256:1e859dd5e06e9807080e14995db131649a77c61131cc464a7fe492a69ce82488"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win_amd64.whl", hash = "sha256:234ecade0e358caa1ea516c218b3f61e61e30532cad1a8bb12f2487325838548"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win_arm64.whl", hash = "sha256:58158d0111e86893bc35aacabe509f951ed303cddf8cdba43533190bde317914"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:24017fd3befd4d7cc4a8c71f4a1e9a44d29fdc91723c5446b0e795ab808adee7"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:dfa0820888cc4549fcce7e6bd8affa7d75198d885ccd0bb0760def4bb8461862"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1d507f331756f7066cde7c9f35ed78fa78223a54369dbc8a34d6da7a5074fa1"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c4a885fb50c05903bc703a00830616e680304fdcdd90fc9535a52e72debe712"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b107cbd764bf68f12a57c7aa5846d868bc7463490a1ac2d0f19bffef624c5a9"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b5f46412c8226d0c8f1f423c324c75afe342e4b854836933579fb484f68598c"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90cbf7e35d597503cbdb5cd85409cbb75f377290bc7e8e37cc5dfe4f5cc66cf8"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:04ee7ba7172cb4484af51b2890f19069d35773698abc8c6ebb651f52fbf41134"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a7acb3e120ddba94372b8146e62ea3a0bce203180e34641c817c87f995c91e0"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:1a2f7ceb58013d167d8c96f10ec9b3a137018c819ba356f68ff1cb74302fd22e"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:482b097637fec5037eb13fe9f9d5fe47e568a5b451d686bc2b854076cc0b50ca"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a8c7f5fad73eb404f4b84c75f3d9d3865b748ded248b7366341db6e516fc502b"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win32.whl", hash = "sha256:f343c39928097175acf2f7d0cba5c00b0f62265d88a173a4ce264266ac849bd9"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win_amd64.whl", hash = "sha256:52d40e074da44e42b2425aedebd513e405a31807036ef597175117a9b01743a6"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win_arm64.whl", hash = "sha256:25cac08c9735e61e5c0b9f7a85c438661fc0e9da226afdfa984c3da6c5942a5a"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:dbe50ffd50c7e1b8e3bedafafd10fabe8a7cf51916f995fd57316eb1293f2439"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:67c3f4b57a1cda846a9b6258e0ef70f99b0757d0b6a55f7e3e5b100620937d2c"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c1e025c83b3987784191ec5aa78cfb94686f1ba73d98ae4d520c09cdecda8d"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547cac87585ab8a6e9495fd3be22ae8ccaabb3c909f6ae589a180677c105cc8"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:d648515c65f8871ceca6a0446be32fb1a0aae414db3907ec0df6cc2380dd0c04"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3292813ac8ed17671a5b1ac8256b6e98b96c29c1c6d061c35899e2f6e0444cac"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf3cfcea028f41a9b42a47f4a53d72ca4274d04e3ee525bd58ca89ea8c5e1910"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52afdd8640d59890539060e91c8a494cf96b463e63b2ba499cc5c23abcb5e96b"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ff201e8b93cd6a3849ec3fecb1d642d0391ca4a387f1162adaadddfd1986598c"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:eba39fca6e1fc7c58a5e6b5c320c6eec86014f7075dfa0fb1a79076e00304b3a"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b494566e8e7e1a3992a550900902d6c0051538e53f5957e24003e7775638a24"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c96f745745fe778a92e74f40675c57b80dd46fc98186c9020e5ba4514a0470bd"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dbfb8f8ac48a8fd5a9f83cdbba0a23796f3ff5f9421c735808ed134d3318046f"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:4ded2d8dc3beb086623ab3256b81a0b51fc026c94eb363f480aec09b204a1cbd"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:25f8aa59fd3f85651387c37b289cd8f0bbc8f632e7dad13c6b837759e78be88d"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:0ae0096729c64aa6155cef74e75d1945017b01cfb82f313a18d843fafc497476"}, + {file = "pydantic_core-2.47.0.tar.gz", hash = "sha256:422c1797a7864b2a9a996435aba92fe571fb80190f67a31edbc1ac040c7b51fe"}, ] [package.dependencies] @@ -1149,23 +1136,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pydantic-v1-wrapped/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pydantic-v1-wrapped/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pydantic-v1-wrapped/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pydantic-v1-wrapped/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pydantic-v1/poetry.lock b/seed/python-sdk/exhaustive/pydantic-v1/poetry.lock index 63fa2ad6c2c8..f2eeffef483a 100644 --- a/seed/python-sdk/exhaustive/pydantic-v1/poetry.lock +++ b/seed/python-sdk/exhaustive/pydantic-v1/poetry.lock @@ -977,132 +977,119 @@ email = ["email-validator (>=1.0.3)"] [[package]] name = "pydantic-core" -version = "2.46.4" +version = "2.47.0" description = "Core functionality for Pydantic validation and serialization" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] files = [ - {file = "pydantic_core-2.46.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a396dcc17e5a0b164dbe026896245a4fa9ff402edca1dff0be3d53a517f74de4"}, - {file = "pydantic_core-2.46.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:da4b951fe36dc7c3a1ccb4e3cd1747c3542b8c9ceede8fc86cae054e764485f5"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb63e0198ca18aad131c089b9204c23079c3afa95487e561f4c522d519e55aba"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f47286a97f0bc9b8859519809077b91b2cefe4ae47fcbf5e466a009c1c5d742b"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905a0ed8ea6f2d61c1738835f99b699348d7857379083e5fc497fa0c967a407c"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea793e075b70290d89d8142074262885d3f7da19634845135751bd6344f73b50"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395aebd9183f9d112f569aeb5b2214d1a10a33bec8456447f7fbdfa51d38d4cd"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:b078afbc25f3a1436c7a1d2cd3e322497ee99615ba97c563566fdf46aff1ee01"}, - {file = "pydantic_core-2.46.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f747929cf940cddb5b3668a390056ddd5ba2e5010615ea2dcf4f9c4f3ab8791d"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:daa27d92c36f24388fe3ad306b174781c747627f134452e4f128ea00ce1fe8c4"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:19e51f073cd3df251856a8a4189fbdf1de4012c3ebacfb1884f94f1eb406079f"}, - {file = "pydantic_core-2.46.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1747f85cee84c26985853c6f3d9bd3e75da5212912443fa111c113b9c246f39"}, - {file = "pydantic_core-2.46.4-cp310-cp310-win32.whl", hash = "sha256:2f84c03c8607173d16b5a854ec68a2f9079ae03237a54fb506d13af47e1d018d"}, - {file = "pydantic_core-2.46.4-cp310-cp310-win_amd64.whl", hash = "sha256:8358a950c8909158e3df31538a7e4edc2d7265a7c54b47f0864d9e5bae9dcebf"}, - {file = "pydantic_core-2.46.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:0e96592440881c74a213e5ad528e2b24d3d4f940de2766bed9010ab1d9e51594"}, - {file = "pydantic_core-2.46.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0d65b8c354be7fb5f720c3caa8bc940bc2d20ce749c8e06135f07f8ed95dd7c"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bfb192b3f4b9e8a89b6277b6ce787564f62cfd272055f6e685726b111dc7826"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9037063db01f09b09e237c282b6792bd4da634b5402c4e7f0c61effed7701a04"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc010ab034c8c7452522748bf937df58020d256ccae0874463d1f4d01758af8e"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c5dac79fa1614d1e06ca695109c6105923bd9c7d1d6c918d4e637b7e6b32fd3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fa868638bf362d3d138ea55829cefb3d5f4b0d7f142234382a15e2485dbec4"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:17299feefe090f2caa5b8e37222bb5f663e4935a8bfa6931d4102e5df1a9f398"}, - {file = "pydantic_core-2.46.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4c63ebc82684aa89d9a3bcbd13d515b3be44250dc68dd3bd81526c1cb31286c3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:aaa2a54443eff1950ba5ddc6b6ccda0d9c84a364276a62f969bdf2a390650848"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:18e5ceec2ab67e6d5f1a9085e5a24c9c4e2ac4545730bfe668680bca05e555f3"}, - {file = "pydantic_core-2.46.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a0f62d0a58f4e7da165457e995725421e0064f2255d8eccebc49f41bbc23b109"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win32.whl", hash = "sha256:041bde0a48fd37cf71cab1c9d56d3e8625a3793fef1f7dd232b3ff37e978ecda"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win_amd64.whl", hash = "sha256:6f2eeda33a839975441c86a4119e1383c50b47faf0cbb5176985565c6bb02c33"}, - {file = "pydantic_core-2.46.4-cp311-cp311-win_arm64.whl", hash = "sha256:14f4c5d6db102bd796a627bbb3a17b4cf4574b9ae861d8b7c9a9661c6dd3362d"}, - {file = "pydantic_core-2.46.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3245406455a5d98187ec35530fd772b1d799b26667980872c8d4614991e2c4a2"}, - {file = "pydantic_core-2.46.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:962ccbab7b642487b1d8b7df90ef677e03134cf1fd8880bf698649b22a69371f"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8233f2947cf85404441fd7e0085f53b10c93e0ee78611099b5c7237e36aacbf7"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a233125ac121aa3ffba9a2b59edfc4a985a76092dc8279586ab4b71390875e7"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b712b53160b79a5850310b912a5ef8e57e56947c8ad690c227f5c9d7e561712"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9401557acd873c3a7f3eb9383edef8ac4968f9510e340f4808d427e75667e7b4"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:926c9541b14b12b1681dca8a0b75feb510b06c6341b70a8e500c2fdcff837cce"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:56cb4851bcaf3d117eddcef4fe66afd750a50274b0da8e22be256d10e5611987"}, - {file = "pydantic_core-2.46.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c68fcd102d71ea85c5b2dfac3f4f8476eff42a9e078fd5faefff6d145063536b"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b2f69dec1725e79a012d920df1707de5caf7ed5e08f3be4435e25803efc47458"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:8d0820e8192167f80d88d64038e609c31452eeca865b4e1d9950a27a4609b00b"}, - {file = "pydantic_core-2.46.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fbdb89b3e1c94a30cc5edfce477c6e6a5dc4d8f84665b455c27582f211a1c72c"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win32.whl", hash = "sha256:9aa768456404a8bf48a4406685ac2bec8e72b62c69313734fa3b73cf33b3a894"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win_amd64.whl", hash = "sha256:e9c26f834c65f5752f3f06cb08cb86a913ceb7274d0db6e267808a708b46bc89"}, - {file = "pydantic_core-2.46.4-cp312-cp312-win_arm64.whl", hash = "sha256:4fc73cb559bdb54b1134a706a2802a4cddd27a0633f5abb7e53056268751ac6a"}, - {file = "pydantic_core-2.46.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:5d5902252db0d3cedf8d4a1bc68f70eeb430f7e4c7104c8c476753519b423008"}, - {file = "pydantic_core-2.46.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94f0688e7b8d0a67abf40e57a7eaaecd17cc9586706a31b76c031f63df052b4"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f027324c56cd5406ca49c124b0db10e56c69064fec039acc571c29020cc87c76"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e739fee756ba1010f8bcccb534252e85a35fe45ae92c295a06059ce58b74ccd3"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d56801be94b86a9da183e5f3766e6310752b99ff647e38b09a9500d88e46e76"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2412e734dcb48da14d4e4006b82b46b74f2518b8a26ee7e58c6844a6cd6d03c4"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9551187363ffc0de2a00b2e47c25aeaeb1020b69b668762966df15fc5659dd5a"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0186750b482eefa11d7f435892b09c5c606193ef3375bcf94aa00ae6bfb66262"}, - {file = "pydantic_core-2.46.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5855698a4856556d86e8e6cd8434bc3ac0314ee8e12089ae0e143f64c6256e4e"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:cbaf13819775b7f769bf4a1f066cb6df7a28d4480081a589828ef190226881cd"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:633147d34cf4550417f12e2b1a0383973bdf5cdfde212cb09e9a581cf10820be"}, - {file = "pydantic_core-2.46.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:82cf5301172168103724d49a1444d3378cb20cdee30b116a1bd6031236298a5d"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win32.whl", hash = "sha256:9fa8ae11da9e2b3126c6426f147e0fba88d96d65921799bb30c6abd1cb2c97fb"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win_amd64.whl", hash = "sha256:6b3ace8194b0e5204818c92802dcdca7fc6d88aabbb799d7c795540d9cd6d292"}, - {file = "pydantic_core-2.46.4-cp313-cp313-win_arm64.whl", hash = "sha256:184c081504d17f1c1066e430e117142b2c77d9448a97f7b65c6ac9fd9aee238d"}, - {file = "pydantic_core-2.46.4-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:428e04521a40150c85216fc8b85e8d39fece235a9cf5e383761238c7fa9b96fb"}, - {file = "pydantic_core-2.46.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:23ace664830ee0bfe014a0c7bc248b1f7f25ed7ad103852c317624a1083af462"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce5c1d2a8b27468f433ca974829c44060b8097eedc39933e3c206a90ee49c4a9"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7283d57845ecf5a163403eb0702dfc220cc4fbdd18919cb5ccea4f95ee1cdab4"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8daafc69c93ee8a0204506a3b6b30f586ef54028f52aeeeb5c4cfc5184fd5914"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd2213145bcc2ba85884d0ac63d222fece9209678f77b9b4d76f054c561adb28"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a5f930472650a82629163023e630d160863fce524c616f4e5186e5de9d9a49b"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:c1b3f518abeca3aa13c712fd202306e145abf59a18b094a6bafb2d2bbf59192c"}, - {file = "pydantic_core-2.46.4-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a7dd0b3ee80d90150e3495a3a13ac34dbcbfd4f012996a6a1d8900e91b5c0fb"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:3fb702cd90b0446a3a1c5e470bfa0dd23c0233b676a9099ddcc964fa6ca13898"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:b8458003118a712e66286df6a707db01c52c0f52f7db8e4a38f0da1d3b94fc4e"}, - {file = "pydantic_core-2.46.4-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:372429a130e469c9cd698925ce5fc50940b7a1336b0d82038e63d5bbc4edc519"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win32.whl", hash = "sha256:85bb3611ff1802f3ee7fdd7dbff26b56f343fb432d57a4728fdd49b6ef35e2f4"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win_amd64.whl", hash = "sha256:811ff8e9c313ab425368bcbb36e5c4ebd7108c2bbf4e4089cfbb0b01eff63fac"}, - {file = "pydantic_core-2.46.4-cp314-cp314-win_arm64.whl", hash = "sha256:bfec22eab3c8cc2ceec0248aec886624116dc079afa027ecc8ad4a7e62010f8a"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:af8244b2bef6aaad6d92cda81372de7f8c8d36c9f0c3ea36e827c60e7d9467a0"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5a4330cdbc57162e4b3aa303f588ba752257694c9c9be3e7ebb11b4aca659b5d"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c61fc04a3d840155ff08e475a04809278972fe6aef51e2720554e96367e34b"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c50f2528cf200c5eed56faf3f4e22fcd5f38c157a8b78576e6ba3168ec35f000"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0cbe8b01f948de4286c74cdd6c667aceb38f5c1e26f0693b3983d9d74887c65e"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:617d7e2ca7dcb8c5cf6bcb8c59b8832c94b36196bbf1cbd1bfb56ed341905edd"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7027560ee92211647d0d34e3f7cd6f50da56399d26a9c8ad0da286d3869a53f3"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:f99626688942fb746e545232e7726926f3be91b5975f8b55327665fafda991c7"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fc3e9034a63de20e15e8ade85358bc6efc614008cab72898b4b4952bea0509ff"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:97e7cf2be5c77b7d1a9713a05605d49460d02c6078d38d8bef3cbe323c548424"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:3bf92c5d0e00fefaab325a4d27828fe6b6e2a21848686b5b60d2d9eeb09d76c6"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:3ecbc122d18468d06ca279dc26a8c2e2d5acb10943bb35e36ae92096dc3b5565"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win32.whl", hash = "sha256:e846ae7835bf0703ae43f534ab79a867146dadd59dc9ca5c8b53d5c8f7c9ef02"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win_amd64.whl", hash = "sha256:2108ba5c1c1eca18030634489dc544844144ee36357f2f9f780b93e7ddbb44b5"}, - {file = "pydantic_core-2.46.4-cp314-cp314t-win_arm64.whl", hash = "sha256:4fcbe087dbc2068af7eda3aa87634eba216dbda64d1ae73c8684b621d33f6596"}, - {file = "pydantic_core-2.46.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:fd8b3d9fd264be37976686c7f65cd52a83f5e84f4bfd2adf9c1d469676bbb6ae"}, - {file = "pydantic_core-2.46.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9f444c499b3eefd3a92e348059471ea0c3a6e303d9c1cec09fa748fd9f895201"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3447661d99f75a3683a4cf5c87da72f2161964611864dbbeac7fbb118bb4bfc0"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b9bab013d1c7a79d3501ff86d0bc9c31bf587db4551677b96bec07df78c6b15"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d995260fdf4e1db774581b4900e0f832abe3c7c84996726bbc161b19c8f29e76"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f13a646d65d09fbf1bc6b3a9635d30095c8e7e5cc419ff35ecc563c5fd04cd49"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432c179df7874eeb73307aad2df0755e1ae0efa61ff0ea89b93e194411ae3928"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_31_riscv64.whl", hash = "sha256:e68b7a074f65a2fd746c52a7ce6142ab7006074ac269ace0c25cd8ba171f8066"}, - {file = "pydantic_core-2.46.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4a05d69cba51d852c5c3e92758653245a50c0b646ced0cf05bd793ed592839d6"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:228ee9bae8bef5b1e97ec58302f80357c37199e0d0a99174e138d28e6957b9d9"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:10e17cbb10a330363733efc4d7c4d0dd827ac0909b8f6a6542298fed1ea62f29"}, - {file = "pydantic_core-2.46.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:91a06d2e259ecfbd8c901d70c3c507900458498142b3026a296b7de4d1322cc9"}, - {file = "pydantic_core-2.46.4-cp39-cp39-win32.whl", hash = "sha256:d80ee3d731373b24cebbc10d689ca4ee1875caf0d5703a245db18efd4dd37fc1"}, - {file = "pydantic_core-2.46.4-cp39-cp39-win_amd64.whl", hash = "sha256:3be77f45df024d789a672ae34f8b06fb346c4f9f46ea714956660ea4862e89ac"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:14d4edf427bdcf950a8a02d7cb44a08614388dd6e1bdcbf4f67504fa7887da9c"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:0ce40cd7b21210e99342afafbd4d0f76d784eb5b1d60f3bdc566be4983c6c73b"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90884113d8b48f760e9587002789ddd741e76ab9f89518cd1e43b1f1a52ec44b"}, - {file = "pydantic_core-2.46.4-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66ce7632c22d837c95301830e111ad0128a32b8207533b60896a96c4915192ea"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:1d8ba486450b14f3b1d63bc521d410ec7565e52f887b9fb671791886436a42f7"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3009f12e4e90b7f88b4f9adb1b0c4a3d58fe7820f3238c190047209d148026df"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad785e92e6dc634c21555edc8bd6b64957ab844541bcb96a1366c202951ae526"}, - {file = "pydantic_core-2.46.4-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00c603d540afdd6b80eb39f078f33ebd46211f02f33e34a32d9f053bba711de0"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:0c563b08bca408dc7f65f700633d8442fffb2421fc47b8101377e9fd65051ff0"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:db06ffe51636ffe9ca531fe9023dd64bdd794be8754cb5df57c5498ae5b518a7"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:133878133d271ade3d41d1bfb2a45ec38dbdbda40bc065921c6b04e4630127e2"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9bc519fbf2b7578398853d815009ae5e4d4603d12f4e3f91da8c06852d3da3e9"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c7a7bd4e39e8e4c12c39cd480356842b6a8a06e41b23a55a5e3e191718838ddf"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:d396ec2b979760aaf3218e76c24e65bd0aca24983298653b3a9d7a45f9e47b30"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:86e1a4418c6cd97d60c95c71164158eaf7324fae7b0923264016baa993eba6fc"}, - {file = "pydantic_core-2.46.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:d51026d73fcfd93610abc7b27789c26b313920fcfb20e27462d74a7f8b06e983"}, - {file = "pydantic_core-2.46.4.tar.gz", hash = "sha256:62f875393d7f270851f20523dd2e29f082bcc82292d66db2b64ea71f64b6e1c1"}, + {file = "pydantic_core-2.47.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d4c7148fc6c0bb727139010e15aab198be6c5d00276f83246b417ce69831f9d6"}, + {file = "pydantic_core-2.47.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:497f91a0499fa4ce7ae982756f8a237af19f145d944258c0c991cfb78aee13b0"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6cd3cde901878cce06787608a50c4456b8ad49c2128440c24b96b5624d26937"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b02b17e44bcb066b9f3a1d31c0be01a59f81d0b94b5066fdded1a8ccc8f819e"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e42feab6c93fa3264708502f5062147073c7e57bc56bc1c44ed00efa53bc1859"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a3c450c65dfd14eee570756f61c823f8a7a36a3f4f4d46a3945d6225dc8a47d8"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0744954b81a77cfd381760d8b1bf92ba4db57d2da235e695d6fd3c94f741d24d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:34336a9cdd8b54e0cbc9d3c36b7be7dd828fa10e4209a7e4b0ca4583aa3e696d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8623f4e598c9cee799075d0f1ab5174f9f2e7c42b3c5c7d859a97b3c726c84f8"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6143463422e3851187657ff677af0cb04c5e1a2c51c028438ce5f20fbb5cb50d"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:152bcbfe9f087716d185f6003be549f2cc6ee3cd4ca67909118e31626afc209b"}, + {file = "pydantic_core-2.47.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad7522d6e0cf28192b30f4c9db62e9b7a13ef10652bf8310de27bcb4a6c1c40"}, + {file = "pydantic_core-2.47.0-cp310-cp310-win32.whl", hash = "sha256:e1c8ea447dcfaa7f7d815d07bddb131383275682601878b5711f59fac68045a2"}, + {file = "pydantic_core-2.47.0-cp310-cp310-win_amd64.whl", hash = "sha256:df82086e6efb002a8e4f8f787dd2ddf9db46403fe8697b7620111663799b62b8"}, + {file = "pydantic_core-2.47.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fe87ccbc39a103709d0a5afa75240c15a94611af129261e9484bef0bc97960b2"}, + {file = "pydantic_core-2.47.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45882c24324f123037982c65eb8d60da778447e6bc87c82241f81d6c6d2c307e"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:983e39b39547772543f3518557d0a86dbb3b7bb58bf8e82faba1e0cfa3e816d9"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3eb92447d3a079b945b61b8cbd6c3ec2954de3655c4efa0ebd35b069e472c2a9"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af9600c680bec4b8d23c32ddaf7a5d91ed39a2cf758c082e34e860140cdcd87"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d6098342d4510a9034a500a53b1d737daf9cfc18a47cd21047d02d7d1587557"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a891c20be5110deb1904f639f3615ec5022b3495995850d1abe7b8fa1550b5"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:13990d357a50078e382b15fa3ce3f08043223b4be3eaeb340b184f54c1a2397a"}, + {file = "pydantic_core-2.47.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a0f7343305387bb5884f24d384b7978ad099a277b27529e592c041a502a37c32"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2c5e11bb1be2de2707c9367f364e73091ef30d34be54b3a4564d7421ac1a16bd"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:e1b52aa981e034896712460c899ee30707c8c6a385e79bb7648aac76c748a3da"}, + {file = "pydantic_core-2.47.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d57d46021c20d4efc28f69b4ca4670dffbb7bdafa51d1967b747849726ec643b"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win32.whl", hash = "sha256:d93c02ae8bd33f73624319d85cba47e754155c5bf104c0c5ca96fcd1f3094939"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win_amd64.whl", hash = "sha256:859ca679f00e5feb11b58b616eb7bc0efbb13654be21f5c898e510e27671c900"}, + {file = "pydantic_core-2.47.0-cp311-cp311-win_arm64.whl", hash = "sha256:482667e5b7a3e97b0836f33a716199c4ec6ba9c896ca4db6eae799ec527c1e64"}, + {file = "pydantic_core-2.47.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:70a7aeba54854f5d97da65cb1a61f000f53df3704cab41cd81d65ab127ddd031"}, + {file = "pydantic_core-2.47.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c0e97329e38228f57fe1f2d91ba0ef39cc75cc1a84fe6ef58942d2fc6cb406bf"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:faf83e50714837af72f13e9369c50377552a4a74049d4477bed51c7e5822d94b"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f50abe60347bf8afe2b2f58db86cf3ac6e418eec7ffa01d9dc90ba29fc64f243"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13b18e7dec056336f29ee77dea3cc5db0271d6215cac7249cc5c61b0a49d293"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3aaabbdbbbb8dff33fa053ffb2c980f39dec745fc03592f50e1e010449129841"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d2907ee6d15cf26787bfbeb4c42e18e52f358086eab91baea961a0d909248d6"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:c413761260967f4dbb51135e1b49f30a1c29e15bf371fbb39754ed6475739545"}, + {file = "pydantic_core-2.47.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2bb527ac6e5a9721023b24615e1a55c01f47c60007f08b7d2afb89ff9c7a0e22"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b3abf7fd5e6abe483a63413f9cd26b7c93c20780e19c8556434c7279f6b2f10c"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:9cefe43d17a5a273d71697c084d3787defa7f578cd5fab4cef4c66d13c9e44b2"}, + {file = "pydantic_core-2.47.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:38a8b5371b7938e4f6c31060dbd51d3b3229aef7c43eaac2eb8b153e43c3f189"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win32.whl", hash = "sha256:b87e95e644df2a36bd631dd0d6e097aa73d19a55adf7b1724ebdeece3d9c76b7"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0078c5695322050ccedbc86eafaf3e2548439782c51d99e575de0e31b9fe4f4"}, + {file = "pydantic_core-2.47.0-cp312-cp312-win_arm64.whl", hash = "sha256:7eedc31996e9eba3bdfbbc380805ac6d765c889b7e93b17cd00ecb0200fd6dca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:263560ece98bffbbc0a8047ce60b8a278c859db6a2a4e30d9454b02891045eca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a7fdaff39a66bf66e9037da482575513d2f20bfb02ea9d9222b5cb3b902fc695"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3e6c8ee5ce8c270bfae09763ae4bbbccfe81090c97d670a621fb86cb1ef6042"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e38cdae682cfec4b3816722dccf6376ca59049726d57dca83c2fe7cc13665589"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc3193ff0b7e2e168f84c6185e70475738c191f3154e0af8f897cd0f8f9a489"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57ff41672a615f38af528ee904602be51c653248354e5db8e9252668abe91e68"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473b9a2b2a1f0dd55cbb32d2b902f93babe7f141a0bb48fb4d3d4d2b3e93e9a0"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:195f9c4ac43a7b2a044a7b86631c3352abdb820bed2823ea29f98f779255f459"}, + {file = "pydantic_core-2.47.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6c1cd10d39ef1ff8bcd68b6865bee9c434631ac0608d402fe86e678851c2e2a5"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7cbe66352fe2b39511d49150e5b52159429cd21f5633a3e801dd2c43829dcdca"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:e35192a1d53e55d510d8bb1023c988c7cdae6d94539074971741b2a7656e49a1"}, + {file = "pydantic_core-2.47.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:81de54576de2e20baec76cc5afae2820f9049e6fdc4f357bac3391da02d0ba97"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win32.whl", hash = "sha256:05da6647bdfd3888936ac10aa39b239d659f3c93dff281af0fc5943eb55629dd"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win_amd64.whl", hash = "sha256:021220e0a03b66112737ee1fc49759340ce8fafb8d9ade1b7fb366b06033fa45"}, + {file = "pydantic_core-2.47.0-cp313-cp313-win_arm64.whl", hash = "sha256:55156ee2f6f561ea4e25ab55f84bd70b9c9ed2546a834cb2b038fe10225aaa37"}, + {file = "pydantic_core-2.47.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:6e37a6974fbd8fa7cae12285a76970d50b3689ffd6ed7c7fdd176ba81dd22d0e"}, + {file = "pydantic_core-2.47.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2433b8524785cc117e602233bc574879bc8d87f09523edeec51665d5c46cf42d"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8f9f2be064c8bf1189f46f7062fd42765d94f59cfb7db7ef8db19563192110a"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1f0a9659a2eb161573418e3138f616101ba21bbd2ff04916dca7b6712155e015"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2995074b99242aa28991e0120a3c881babc139e08750a05b7ea7d140644e091d"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5b224dc04c3ff9b08c24419464eb7f6ad7a1049e12284a00bf80df82bd15fdb"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:264361b7236d4374fef6342908f87d084a0d58a2f8d0811e99f714309cb0ba7e"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:53368beaf693f6302a6e33bdefe950857534a04d282811421bd20176d0fb5636"}, + {file = "pydantic_core-2.47.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5d2177b44ba7d9d86850f865f362feeaac6a2ed8517a9b505b97ff0b7fdbd7dd"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:fb57a6b538ec7a01b937986bc093aec530fb056135b6bc9cfdd0bf8460c25bc2"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:18c9c7c3a18e9bdbf1215d913f6bd00e17595dc92949817935cb87a3cf5f1697"}, + {file = "pydantic_core-2.47.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:2f72a382886ca85bb1247303b9134cf9978c9d454de62e710a1ebcd9d2131927"}, + {file = "pydantic_core-2.47.0-cp314-cp314-pyemscripten_2026_0_wasm32.whl", hash = "sha256:cdf4dc2cdd0eacad1bd81c4d25422b4c25b206acae095d2d64e5d5cb7facc6b3"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win32.whl", hash = "sha256:1e859dd5e06e9807080e14995db131649a77c61131cc464a7fe492a69ce82488"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win_amd64.whl", hash = "sha256:234ecade0e358caa1ea516c218b3f61e61e30532cad1a8bb12f2487325838548"}, + {file = "pydantic_core-2.47.0-cp314-cp314-win_arm64.whl", hash = "sha256:58158d0111e86893bc35aacabe509f951ed303cddf8cdba43533190bde317914"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:24017fd3befd4d7cc4a8c71f4a1e9a44d29fdc91723c5446b0e795ab808adee7"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:dfa0820888cc4549fcce7e6bd8affa7d75198d885ccd0bb0760def4bb8461862"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1d507f331756f7066cde7c9f35ed78fa78223a54369dbc8a34d6da7a5074fa1"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c4a885fb50c05903bc703a00830616e680304fdcdd90fc9535a52e72debe712"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b107cbd764bf68f12a57c7aa5846d868bc7463490a1ac2d0f19bffef624c5a9"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b5f46412c8226d0c8f1f423c324c75afe342e4b854836933579fb484f68598c"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90cbf7e35d597503cbdb5cd85409cbb75f377290bc7e8e37cc5dfe4f5cc66cf8"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:04ee7ba7172cb4484af51b2890f19069d35773698abc8c6ebb651f52fbf41134"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a7acb3e120ddba94372b8146e62ea3a0bce203180e34641c817c87f995c91e0"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:1a2f7ceb58013d167d8c96f10ec9b3a137018c819ba356f68ff1cb74302fd22e"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:482b097637fec5037eb13fe9f9d5fe47e568a5b451d686bc2b854076cc0b50ca"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a8c7f5fad73eb404f4b84c75f3d9d3865b748ded248b7366341db6e516fc502b"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win32.whl", hash = "sha256:f343c39928097175acf2f7d0cba5c00b0f62265d88a173a4ce264266ac849bd9"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win_amd64.whl", hash = "sha256:52d40e074da44e42b2425aedebd513e405a31807036ef597175117a9b01743a6"}, + {file = "pydantic_core-2.47.0-cp314-cp314t-win_arm64.whl", hash = "sha256:25cac08c9735e61e5c0b9f7a85c438661fc0e9da226afdfa984c3da6c5942a5a"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:dbe50ffd50c7e1b8e3bedafafd10fabe8a7cf51916f995fd57316eb1293f2439"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:67c3f4b57a1cda846a9b6258e0ef70f99b0757d0b6a55f7e3e5b100620937d2c"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c1e025c83b3987784191ec5aa78cfb94686f1ba73d98ae4d520c09cdecda8d"}, + {file = "pydantic_core-2.47.0-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547cac87585ab8a6e9495fd3be22ae8ccaabb3c909f6ae589a180677c105cc8"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:d648515c65f8871ceca6a0446be32fb1a0aae414db3907ec0df6cc2380dd0c04"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3292813ac8ed17671a5b1ac8256b6e98b96c29c1c6d061c35899e2f6e0444cac"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf3cfcea028f41a9b42a47f4a53d72ca4274d04e3ee525bd58ca89ea8c5e1910"}, + {file = "pydantic_core-2.47.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52afdd8640d59890539060e91c8a494cf96b463e63b2ba499cc5c23abcb5e96b"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ff201e8b93cd6a3849ec3fecb1d642d0391ca4a387f1162adaadddfd1986598c"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:eba39fca6e1fc7c58a5e6b5c320c6eec86014f7075dfa0fb1a79076e00304b3a"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b494566e8e7e1a3992a550900902d6c0051538e53f5957e24003e7775638a24"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c96f745745fe778a92e74f40675c57b80dd46fc98186c9020e5ba4514a0470bd"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dbfb8f8ac48a8fd5a9f83cdbba0a23796f3ff5f9421c735808ed134d3318046f"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:4ded2d8dc3beb086623ab3256b81a0b51fc026c94eb363f480aec09b204a1cbd"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:25f8aa59fd3f85651387c37b289cd8f0bbc8f632e7dad13c6b837759e78be88d"}, + {file = "pydantic_core-2.47.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:0ae0096729c64aa6155cef74e75d1945017b01cfb82f313a18d843fafc497476"}, + {file = "pydantic_core-2.47.0.tar.gz", hash = "sha256:422c1797a7864b2a9a996435aba92fe571fb80190f67a31edbc1ac040c7b51fe"}, ] [package.dependencies] @@ -1149,23 +1136,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pydantic-v1/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pydantic-v1/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pydantic-v1/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pydantic-v1/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pydantic-v2-wrapped/poetry.lock b/seed/python-sdk/exhaustive/pydantic-v2-wrapped/poetry.lock index b7dcce559ab7..28b6313559f1 100644 --- a/seed/python-sdk/exhaustive/pydantic-v2-wrapped/poetry.lock +++ b/seed/python-sdk/exhaustive/pydantic-v2-wrapped/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pydantic-v2-wrapped/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pydantic-v2-wrapped/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pydantic-v2-wrapped/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pydantic-v2-wrapped/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/pyproject_extras/poetry.lock b/seed/python-sdk/exhaustive/pyproject_extras/poetry.lock index 2eedbaf087b6..bb8e47ff41ce 100644 --- a/seed/python-sdk/exhaustive/pyproject_extras/poetry.lock +++ b/seed/python-sdk/exhaustive/pyproject_extras/poetry.lock @@ -1148,23 +1148,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/pyproject_extras/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/pyproject_extras/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/pyproject_extras/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/pyproject_extras/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/skip-pydantic-validation/poetry.lock b/seed/python-sdk/exhaustive/skip-pydantic-validation/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/skip-pydantic-validation/poetry.lock +++ b/seed/python-sdk/exhaustive/skip-pydantic-validation/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/skip-pydantic-validation/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/skip-pydantic-validation/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/skip-pydantic-validation/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/skip-pydantic-validation/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/union-utils/poetry.lock b/seed/python-sdk/exhaustive/union-utils/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/exhaustive/union-utils/poetry.lock +++ b/seed/python-sdk/exhaustive/union-utils/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/union-utils/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/union-utils/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/union-utils/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/union-utils/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/exhaustive/wire-tests-custom-client-name/poetry.lock b/seed/python-sdk/exhaustive/wire-tests-custom-client-name/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/exhaustive/wire-tests-custom-client-name/poetry.lock +++ b/seed/python-sdk/exhaustive/wire-tests-custom-client-name/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/exhaustive/wire-tests-custom-client-name/src/seed/core/pydantic_utilities.py b/seed/python-sdk/exhaustive/wire-tests-custom-client-name/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/exhaustive/wire-tests-custom-client-name/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/exhaustive/wire-tests-custom-client-name/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/extends/poetry.lock b/seed/python-sdk/extends/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/extends/poetry.lock +++ b/seed/python-sdk/extends/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/extends/src/seed/core/pydantic_utilities.py b/seed/python-sdk/extends/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/extends/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/extends/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/extra-properties/poetry.lock b/seed/python-sdk/extra-properties/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/extra-properties/poetry.lock +++ b/seed/python-sdk/extra-properties/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/extra-properties/src/seed/core/pydantic_utilities.py b/seed/python-sdk/extra-properties/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/extra-properties/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/extra-properties/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/file-download/default-chunk-size/poetry.lock b/seed/python-sdk/file-download/default-chunk-size/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/file-download/default-chunk-size/poetry.lock +++ b/seed/python-sdk/file-download/default-chunk-size/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/file-download/default-chunk-size/src/seed/core/pydantic_utilities.py b/seed/python-sdk/file-download/default-chunk-size/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/file-download/default-chunk-size/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/file-download/default-chunk-size/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/file-download/no-custom-config/poetry.lock b/seed/python-sdk/file-download/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/file-download/no-custom-config/poetry.lock +++ b/seed/python-sdk/file-download/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/file-download/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/file-download/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/file-download/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/file-download/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/file-upload-openapi/poetry.lock b/seed/python-sdk/file-upload-openapi/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/file-upload-openapi/poetry.lock +++ b/seed/python-sdk/file-upload-openapi/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/file-upload-openapi/src/seed/core/pydantic_utilities.py b/seed/python-sdk/file-upload-openapi/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/file-upload-openapi/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/file-upload-openapi/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/file-upload/exclude_types_from_init_exports/poetry.lock b/seed/python-sdk/file-upload/exclude_types_from_init_exports/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/file-upload/exclude_types_from_init_exports/poetry.lock +++ b/seed/python-sdk/file-upload/exclude_types_from_init_exports/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/file-upload/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py b/seed/python-sdk/file-upload/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/file-upload/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/file-upload/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/file-upload/no-custom-config/poetry.lock b/seed/python-sdk/file-upload/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/file-upload/no-custom-config/poetry.lock +++ b/seed/python-sdk/file-upload/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/file-upload/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/file-upload/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/file-upload/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/file-upload/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/file-upload/use_typeddict_requests/poetry.lock b/seed/python-sdk/file-upload/use_typeddict_requests/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/file-upload/use_typeddict_requests/poetry.lock +++ b/seed/python-sdk/file-upload/use_typeddict_requests/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/file-upload/use_typeddict_requests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/file-upload/use_typeddict_requests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/file-upload/use_typeddict_requests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/file-upload/use_typeddict_requests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/folders/poetry.lock b/seed/python-sdk/folders/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/folders/poetry.lock +++ b/seed/python-sdk/folders/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/folders/src/seed/core/pydantic_utilities.py b/seed/python-sdk/folders/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/folders/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/folders/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/header-auth-environment-variable/poetry.lock b/seed/python-sdk/header-auth-environment-variable/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/header-auth-environment-variable/poetry.lock +++ b/seed/python-sdk/header-auth-environment-variable/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/header-auth-environment-variable/src/seed/core/pydantic_utilities.py b/seed/python-sdk/header-auth-environment-variable/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/header-auth-environment-variable/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/header-auth-environment-variable/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/header-auth/poetry.lock b/seed/python-sdk/header-auth/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/header-auth/poetry.lock +++ b/seed/python-sdk/header-auth/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/header-auth/src/seed/core/pydantic_utilities.py b/seed/python-sdk/header-auth/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/header-auth/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/header-auth/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/http-head/poetry.lock b/seed/python-sdk/http-head/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/http-head/poetry.lock +++ b/seed/python-sdk/http-head/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/http-head/src/seed/core/pydantic_utilities.py b/seed/python-sdk/http-head/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/http-head/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/http-head/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/idempotency-headers/poetry.lock b/seed/python-sdk/idempotency-headers/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/idempotency-headers/poetry.lock +++ b/seed/python-sdk/idempotency-headers/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/idempotency-headers/src/seed/core/pydantic_utilities.py b/seed/python-sdk/idempotency-headers/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/idempotency-headers/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/idempotency-headers/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/imdb/poetry.lock b/seed/python-sdk/imdb/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/imdb/poetry.lock +++ b/seed/python-sdk/imdb/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/imdb/src/seed/core/pydantic_utilities.py b/seed/python-sdk/imdb/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/imdb/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/imdb/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/inferred-auth-explicit/poetry.lock b/seed/python-sdk/inferred-auth-explicit/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/inferred-auth-explicit/poetry.lock +++ b/seed/python-sdk/inferred-auth-explicit/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/inferred-auth-explicit/src/seed/core/pydantic_utilities.py b/seed/python-sdk/inferred-auth-explicit/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/inferred-auth-explicit/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/inferred-auth-explicit/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/inferred-auth-implicit-api-key/poetry.lock b/seed/python-sdk/inferred-auth-implicit-api-key/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/inferred-auth-implicit-api-key/poetry.lock +++ b/seed/python-sdk/inferred-auth-implicit-api-key/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/inferred-auth-implicit-api-key/src/seed/core/pydantic_utilities.py b/seed/python-sdk/inferred-auth-implicit-api-key/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/inferred-auth-implicit-api-key/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/inferred-auth-implicit-api-key/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/inferred-auth-implicit-no-expiry/poetry.lock b/seed/python-sdk/inferred-auth-implicit-no-expiry/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/inferred-auth-implicit-no-expiry/poetry.lock +++ b/seed/python-sdk/inferred-auth-implicit-no-expiry/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/inferred-auth-implicit-no-expiry/src/seed/core/pydantic_utilities.py b/seed/python-sdk/inferred-auth-implicit-no-expiry/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/inferred-auth-implicit-no-expiry/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/inferred-auth-implicit-no-expiry/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/inferred-auth-implicit-reference/poetry.lock b/seed/python-sdk/inferred-auth-implicit-reference/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/inferred-auth-implicit-reference/poetry.lock +++ b/seed/python-sdk/inferred-auth-implicit-reference/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/inferred-auth-implicit-reference/src/seed/core/pydantic_utilities.py b/seed/python-sdk/inferred-auth-implicit-reference/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/inferred-auth-implicit-reference/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/inferred-auth-implicit-reference/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/inferred-auth-implicit/poetry.lock b/seed/python-sdk/inferred-auth-implicit/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/inferred-auth-implicit/poetry.lock +++ b/seed/python-sdk/inferred-auth-implicit/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/inferred-auth-implicit/src/seed/core/pydantic_utilities.py b/seed/python-sdk/inferred-auth-implicit/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/inferred-auth-implicit/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/inferred-auth-implicit/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/license/poetry.lock b/seed/python-sdk/license/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/license/poetry.lock +++ b/seed/python-sdk/license/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/license/src/seed/core/pydantic_utilities.py b/seed/python-sdk/license/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/license/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/license/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/literal-user-agent/no-custom-config/poetry.lock b/seed/python-sdk/literal-user-agent/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/literal-user-agent/no-custom-config/poetry.lock +++ b/seed/python-sdk/literal-user-agent/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/literal-user-agent/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/literal-user-agent/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/literal-user-agent/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/literal-user-agent/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/literal/no-custom-config/poetry.lock b/seed/python-sdk/literal/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/literal/no-custom-config/poetry.lock +++ b/seed/python-sdk/literal/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/literal/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/literal/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/literal/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/literal/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/literal/use_typeddict_requests/poetry.lock b/seed/python-sdk/literal/use_typeddict_requests/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/literal/use_typeddict_requests/poetry.lock +++ b/seed/python-sdk/literal/use_typeddict_requests/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/literal/use_typeddict_requests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/literal/use_typeddict_requests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/literal/use_typeddict_requests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/literal/use_typeddict_requests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/literals-unions/poetry.lock b/seed/python-sdk/literals-unions/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/literals-unions/poetry.lock +++ b/seed/python-sdk/literals-unions/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/literals-unions/src/seed/core/pydantic_utilities.py b/seed/python-sdk/literals-unions/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/literals-unions/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/literals-unions/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/mixed-case/poetry.lock b/seed/python-sdk/mixed-case/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/mixed-case/poetry.lock +++ b/seed/python-sdk/mixed-case/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/mixed-case/src/seed/core/pydantic_utilities.py b/seed/python-sdk/mixed-case/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/mixed-case/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/mixed-case/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/poetry.lock b/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/poetry.lock +++ b/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py b/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/mixed-file-directory/exclude_types_from_init_exports/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/mixed-file-directory/no-custom-config/poetry.lock b/seed/python-sdk/mixed-file-directory/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/mixed-file-directory/no-custom-config/poetry.lock +++ b/seed/python-sdk/mixed-file-directory/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/mixed-file-directory/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/mixed-file-directory/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/mixed-file-directory/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/mixed-file-directory/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/multi-line-docs/poetry.lock b/seed/python-sdk/multi-line-docs/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/multi-line-docs/poetry.lock +++ b/seed/python-sdk/multi-line-docs/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/multi-line-docs/src/seed/core/pydantic_utilities.py b/seed/python-sdk/multi-line-docs/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/multi-line-docs/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/multi-line-docs/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/multi-url-environment-no-default/poetry.lock b/seed/python-sdk/multi-url-environment-no-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/multi-url-environment-no-default/poetry.lock +++ b/seed/python-sdk/multi-url-environment-no-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/multi-url-environment-no-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/multi-url-environment-no-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/multi-url-environment-no-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/multi-url-environment-no-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/multi-url-environment-reference/poetry.lock b/seed/python-sdk/multi-url-environment-reference/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/multi-url-environment-reference/poetry.lock +++ b/seed/python-sdk/multi-url-environment-reference/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/multi-url-environment-reference/src/seed/core/pydantic_utilities.py b/seed/python-sdk/multi-url-environment-reference/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/multi-url-environment-reference/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/multi-url-environment-reference/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/multi-url-environment/poetry.lock b/seed/python-sdk/multi-url-environment/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/multi-url-environment/poetry.lock +++ b/seed/python-sdk/multi-url-environment/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/multi-url-environment/src/seed/core/pydantic_utilities.py b/seed/python-sdk/multi-url-environment/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/multi-url-environment/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/multi-url-environment/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/multiple-request-bodies/poetry.lock b/seed/python-sdk/multiple-request-bodies/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/multiple-request-bodies/poetry.lock +++ b/seed/python-sdk/multiple-request-bodies/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/multiple-request-bodies/src/seed/core/pydantic_utilities.py b/seed/python-sdk/multiple-request-bodies/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/multiple-request-bodies/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/multiple-request-bodies/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/no-content-response/poetry.lock b/seed/python-sdk/no-content-response/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/no-content-response/poetry.lock +++ b/seed/python-sdk/no-content-response/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/no-content-response/src/seed/core/pydantic_utilities.py b/seed/python-sdk/no-content-response/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/no-content-response/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/no-content-response/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/no-environment/poetry.lock b/seed/python-sdk/no-environment/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/no-environment/poetry.lock +++ b/seed/python-sdk/no-environment/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/no-environment/src/seed/core/pydantic_utilities.py b/seed/python-sdk/no-environment/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/no-environment/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/no-environment/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/no-retries/poetry.lock b/seed/python-sdk/no-retries/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/no-retries/poetry.lock +++ b/seed/python-sdk/no-retries/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/no-retries/src/seed/core/pydantic_utilities.py b/seed/python-sdk/no-retries/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/no-retries/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/no-retries/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/null-type/poetry.lock b/seed/python-sdk/null-type/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/null-type/poetry.lock +++ b/seed/python-sdk/null-type/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/null-type/src/seed/core/pydantic_utilities.py b/seed/python-sdk/null-type/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/null-type/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/null-type/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/nullable-allof-extends/poetry.lock b/seed/python-sdk/nullable-allof-extends/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/nullable-allof-extends/poetry.lock +++ b/seed/python-sdk/nullable-allof-extends/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/nullable-allof-extends/src/seed/core/pydantic_utilities.py b/seed/python-sdk/nullable-allof-extends/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/nullable-allof-extends/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/nullable-allof-extends/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/nullable-optional/poetry.lock b/seed/python-sdk/nullable-optional/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/nullable-optional/poetry.lock +++ b/seed/python-sdk/nullable-optional/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/nullable-optional/src/seed/core/pydantic_utilities.py b/seed/python-sdk/nullable-optional/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/nullable-optional/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/nullable-optional/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/nullable-request-body/poetry.lock b/seed/python-sdk/nullable-request-body/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/nullable-request-body/poetry.lock +++ b/seed/python-sdk/nullable-request-body/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/nullable-request-body/src/seed/core/pydantic_utilities.py b/seed/python-sdk/nullable-request-body/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/nullable-request-body/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/nullable-request-body/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/nullable/no-custom-config/poetry.lock b/seed/python-sdk/nullable/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/nullable/no-custom-config/poetry.lock +++ b/seed/python-sdk/nullable/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/nullable/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/nullable/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/nullable/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/nullable/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/nullable/use-typeddict-requests/poetry.lock b/seed/python-sdk/nullable/use-typeddict-requests/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/nullable/use-typeddict-requests/poetry.lock +++ b/seed/python-sdk/nullable/use-typeddict-requests/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/nullable/use-typeddict-requests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/nullable/use-typeddict-requests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/nullable/use-typeddict-requests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/nullable/use-typeddict-requests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-custom/poetry.lock b/seed/python-sdk/oauth-client-credentials-custom/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-custom/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-custom/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-custom/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-default/poetry.lock b/seed/python-sdk/oauth-client-credentials-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-default/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-environment-variables/poetry.lock b/seed/python-sdk/oauth-client-credentials-environment-variables/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-environment-variables/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-environment-variables/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-environment-variables/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-environment-variables/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-environment-variables/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-environment-variables/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/poetry.lock b/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-mandatory-auth/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-nested-root/poetry.lock b/seed/python-sdk/oauth-client-credentials-nested-root/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-nested-root/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-nested-root/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-nested-root/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-nested-root/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-nested-root/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-nested-root/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-openapi/poetry.lock b/seed/python-sdk/oauth-client-credentials-openapi/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-openapi/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-openapi/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-openapi/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-openapi/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-openapi/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-openapi/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-reference/poetry.lock b/seed/python-sdk/oauth-client-credentials-reference/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-reference/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-reference/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-reference/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-reference/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-reference/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-reference/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials-with-variables/poetry.lock b/seed/python-sdk/oauth-client-credentials-with-variables/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials-with-variables/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials-with-variables/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials-with-variables/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials-with-variables/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials-with-variables/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials-with-variables/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/oauth-client-credentials/poetry.lock b/seed/python-sdk/oauth-client-credentials/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/oauth-client-credentials/poetry.lock +++ b/seed/python-sdk/oauth-client-credentials/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/oauth-client-credentials/src/seed/core/pydantic_utilities.py b/seed/python-sdk/oauth-client-credentials/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/oauth-client-credentials/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/oauth-client-credentials/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/object/poetry.lock b/seed/python-sdk/object/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/object/poetry.lock +++ b/seed/python-sdk/object/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/object/src/seed/core/pydantic_utilities.py b/seed/python-sdk/object/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/object/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/object/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/objects-with-imports/poetry.lock b/seed/python-sdk/objects-with-imports/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/objects-with-imports/poetry.lock +++ b/seed/python-sdk/objects-with-imports/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/objects-with-imports/src/seed/core/pydantic_utilities.py b/seed/python-sdk/objects-with-imports/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/objects-with-imports/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/objects-with-imports/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/openapi-request-body-ref/poetry.lock b/seed/python-sdk/openapi-request-body-ref/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/openapi-request-body-ref/poetry.lock +++ b/seed/python-sdk/openapi-request-body-ref/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/openapi-request-body-ref/src/seed/core/pydantic_utilities.py b/seed/python-sdk/openapi-request-body-ref/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/openapi-request-body-ref/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/openapi-request-body-ref/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/optional/poetry.lock b/seed/python-sdk/optional/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/optional/poetry.lock +++ b/seed/python-sdk/optional/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/optional/src/seed/core/pydantic_utilities.py b/seed/python-sdk/optional/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/optional/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/optional/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/package-yml/poetry.lock b/seed/python-sdk/package-yml/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/package-yml/poetry.lock +++ b/seed/python-sdk/package-yml/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/package-yml/src/seed/core/pydantic_utilities.py b/seed/python-sdk/package-yml/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/package-yml/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/package-yml/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/pagination-custom/poetry.lock b/seed/python-sdk/pagination-custom/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/pagination-custom/poetry.lock +++ b/seed/python-sdk/pagination-custom/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/pagination-custom/src/seed/core/pydantic_utilities.py b/seed/python-sdk/pagination-custom/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/pagination-custom/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/pagination-custom/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/pagination-uri-path/poetry.lock b/seed/python-sdk/pagination-uri-path/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/pagination-uri-path/poetry.lock +++ b/seed/python-sdk/pagination-uri-path/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/pagination-uri-path/src/seed/core/pydantic_utilities.py b/seed/python-sdk/pagination-uri-path/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/pagination-uri-path/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/pagination-uri-path/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/pagination/no-custom-config/poetry.lock b/seed/python-sdk/pagination/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/pagination/no-custom-config/poetry.lock +++ b/seed/python-sdk/pagination/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/pagination/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/pagination/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/pagination/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/pagination/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/pagination/no-inheritance-for-extended-models/poetry.lock b/seed/python-sdk/pagination/no-inheritance-for-extended-models/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/pagination/no-inheritance-for-extended-models/poetry.lock +++ b/seed/python-sdk/pagination/no-inheritance-for-extended-models/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/pagination/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py b/seed/python-sdk/pagination/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/pagination/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/pagination/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/pagination/page-index-semantics/poetry.lock b/seed/python-sdk/pagination/page-index-semantics/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/pagination/page-index-semantics/poetry.lock +++ b/seed/python-sdk/pagination/page-index-semantics/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/pagination/page-index-semantics/src/seed/core/pydantic_utilities.py b/seed/python-sdk/pagination/page-index-semantics/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/pagination/page-index-semantics/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/pagination/page-index-semantics/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/path-parameters/poetry.lock b/seed/python-sdk/path-parameters/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/path-parameters/poetry.lock +++ b/seed/python-sdk/path-parameters/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/path-parameters/src/seed/core/pydantic_utilities.py b/seed/python-sdk/path-parameters/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/path-parameters/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/path-parameters/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/plain-text/poetry.lock b/seed/python-sdk/plain-text/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/plain-text/poetry.lock +++ b/seed/python-sdk/plain-text/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/plain-text/src/seed/core/pydantic_utilities.py b/seed/python-sdk/plain-text/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/plain-text/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/plain-text/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/property-access/poetry.lock b/seed/python-sdk/property-access/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/property-access/poetry.lock +++ b/seed/python-sdk/property-access/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/property-access/src/seed/core/pydantic_utilities.py b/seed/python-sdk/property-access/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/property-access/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/property-access/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/public-object/poetry.lock b/seed/python-sdk/public-object/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/public-object/poetry.lock +++ b/seed/python-sdk/public-object/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/public-object/src/seed/core/pydantic_utilities.py b/seed/python-sdk/public-object/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/public-object/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/public-object/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-backslash-escape/poetry.lock b/seed/python-sdk/python-backslash-escape/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/python-backslash-escape/poetry.lock +++ b/seed/python-sdk/python-backslash-escape/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-backslash-escape/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-backslash-escape/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-backslash-escape/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-backslash-escape/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-mypy-exclude/no-custom-config/poetry.lock b/seed/python-sdk/python-mypy-exclude/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/python-mypy-exclude/no-custom-config/poetry.lock +++ b/seed/python-sdk/python-mypy-exclude/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-mypy-exclude/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-mypy-exclude/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-mypy-exclude/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-mypy-exclude/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/poetry.lock b/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/poetry.lock +++ b/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-mypy-exclude/with-mypy-exclude/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-positional-single-property/no-custom-config/poetry.lock b/seed/python-sdk/python-positional-single-property/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/python-positional-single-property/no-custom-config/poetry.lock +++ b/seed/python-sdk/python-positional-single-property/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-positional-single-property/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-positional-single-property/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-positional-single-property/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-positional-single-property/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-positional-single-property/with-positional-constructors/poetry.lock b/seed/python-sdk/python-positional-single-property/with-positional-constructors/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/python-positional-single-property/with-positional-constructors/poetry.lock +++ b/seed/python-sdk/python-positional-single-property/with-positional-constructors/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-positional-single-property/with-positional-constructors/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-positional-single-property/with-positional-constructors/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-positional-single-property/with-positional-constructors/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-positional-single-property/with-positional-constructors/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-reserved-keyword-subpackages/poetry.lock b/seed/python-sdk/python-reserved-keyword-subpackages/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/python-reserved-keyword-subpackages/poetry.lock +++ b/seed/python-sdk/python-reserved-keyword-subpackages/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-reserved-keyword-subpackages/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-reserved-keyword-subpackages/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-reserved-keyword-subpackages/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-reserved-keyword-subpackages/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/poetry.lock b/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/poetry.lock +++ b/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/python-streaming-parameter-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/query-param-name-conflict/poetry.lock b/seed/python-sdk/query-param-name-conflict/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/query-param-name-conflict/poetry.lock +++ b/seed/python-sdk/query-param-name-conflict/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/query-param-name-conflict/src/seed/core/pydantic_utilities.py b/seed/python-sdk/query-param-name-conflict/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/query-param-name-conflict/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/query-param-name-conflict/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/poetry.lock b/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/poetry.lock +++ b/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/query-parameters-openapi-as-objects/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/query-parameters-openapi/no-custom-config/poetry.lock b/seed/python-sdk/query-parameters-openapi/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/query-parameters-openapi/no-custom-config/poetry.lock +++ b/seed/python-sdk/query-parameters-openapi/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/query-parameters-openapi/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/query-parameters-openapi/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/query-parameters-openapi/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/query-parameters-openapi/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/query-parameters/no-custom-config/poetry.lock b/seed/python-sdk/query-parameters/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/query-parameters/no-custom-config/poetry.lock +++ b/seed/python-sdk/query-parameters/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/query-parameters/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/query-parameters/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/query-parameters/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/query-parameters/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/request-parameters/poetry.lock b/seed/python-sdk/request-parameters/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/request-parameters/poetry.lock +++ b/seed/python-sdk/request-parameters/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/request-parameters/src/seed/core/pydantic_utilities.py b/seed/python-sdk/request-parameters/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/request-parameters/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/request-parameters/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/required-nullable/poetry.lock b/seed/python-sdk/required-nullable/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/required-nullable/poetry.lock +++ b/seed/python-sdk/required-nullable/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/required-nullable/src/seed/core/pydantic_utilities.py b/seed/python-sdk/required-nullable/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/required-nullable/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/required-nullable/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/reserved-keywords/poetry.lock b/seed/python-sdk/reserved-keywords/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/reserved-keywords/poetry.lock +++ b/seed/python-sdk/reserved-keywords/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/reserved-keywords/src/seed/core/pydantic_utilities.py b/seed/python-sdk/reserved-keywords/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/reserved-keywords/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/reserved-keywords/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/response-property/poetry.lock b/seed/python-sdk/response-property/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/response-property/poetry.lock +++ b/seed/python-sdk/response-property/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/response-property/src/seed/core/pydantic_utilities.py b/seed/python-sdk/response-property/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/response-property/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/response-property/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/schemaless-request-body-examples/poetry.lock b/seed/python-sdk/schemaless-request-body-examples/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/schemaless-request-body-examples/poetry.lock +++ b/seed/python-sdk/schemaless-request-body-examples/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/schemaless-request-body-examples/src/seed/core/pydantic_utilities.py b/seed/python-sdk/schemaless-request-body-examples/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/schemaless-request-body-examples/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/schemaless-request-body-examples/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/server-sent-event-examples/poetry.lock b/seed/python-sdk/server-sent-event-examples/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/server-sent-event-examples/poetry.lock +++ b/seed/python-sdk/server-sent-event-examples/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/server-sent-event-examples/src/seed/completions/raw_client.py b/seed/python-sdk/server-sent-event-examples/src/seed/completions/raw_client.py index 6c5fb131f460..ff80ffbf2f28 100644 --- a/seed/python-sdk/server-sent-event-examples/src/seed/completions/raw_client.py +++ b/seed/python-sdk/server-sent-event-examples/src/seed/completions/raw_client.py @@ -1,6 +1,7 @@ # This file was auto-generated by Fern from our API Definition. import contextlib +import json import typing from json.decoder import JSONDecodeError from logging import error, warning @@ -13,6 +14,11 @@ from ..core.pydantic_utilities import parse_obj_as, parse_sse_obj from ..core.request_options import RequestOptions from .errors.bad_request_error import BadRequestError +from .types.completion_event import CompletionEvent +from .types.error_event import ErrorEvent +from .types.event_event import EventEvent +from .types.group_created_event import GroupCreatedEvent +from .types.group_deleted_event import GroupDeletedEvent from .types.stream_event import StreamEvent from .types.stream_event_context_protocol import StreamEventContextProtocol from .types.stream_event_discriminant_in_data import StreamEventDiscriminantInData @@ -230,24 +236,28 @@ def _iter(): for _sse in _event_source.iter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamEventDiscriminantInData, - parse_sse_obj( - sse=_sse, - type_=StreamEventDiscriminantInData, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "group.created": + try: + yield typing.cast( + GroupCreatedEvent, + parse_obj_as( + type_=GroupCreatedEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'group.created': {e}, sse: {_sse!r}") + elif _sse.event == "group.deleted": + try: + yield typing.cast( + GroupDeletedEvent, + parse_obj_as( + type_=GroupDeletedEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'group.deleted': {e}, sse: {_sse!r}") return return HttpResponse(response=_response, data=_iter()) @@ -314,24 +324,39 @@ def _iter(): for _sse in _event_source.iter_sse(): if _sse.data == "[DONE]": return - try: - yield typing.cast( - StreamEventContextProtocol, - parse_sse_obj( - sse=_sse, - type_=StreamEventContextProtocol, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "completion": + try: + yield typing.cast( + CompletionEvent, + parse_obj_as( + type_=CompletionEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'completion': {e}, sse: {_sse!r}") + elif _sse.event == "error": + try: + yield typing.cast( + ErrorEvent, + parse_obj_as( + type_=ErrorEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'error': {e}, sse: {_sse!r}") + elif _sse.event == "event": + try: + yield typing.cast( + EventEvent, + parse_obj_as( + type_=EventEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'event': {e}, sse: {_sse!r}") return return HttpResponse(response=_response, data=_iter()) @@ -571,24 +596,28 @@ async def _iter(): async for _sse in _event_source.aiter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamEventDiscriminantInData, - parse_sse_obj( - sse=_sse, - type_=StreamEventDiscriminantInData, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "group.created": + try: + yield typing.cast( + GroupCreatedEvent, + parse_obj_as( + type_=GroupCreatedEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'group.created': {e}, sse: {_sse!r}") + elif _sse.event == "group.deleted": + try: + yield typing.cast( + GroupDeletedEvent, + parse_obj_as( + type_=GroupDeletedEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'group.deleted': {e}, sse: {_sse!r}") return return AsyncHttpResponse(response=_response, data=_iter()) @@ -655,24 +684,39 @@ async def _iter(): async for _sse in _event_source.aiter_sse(): if _sse.data == "[DONE]": return - try: - yield typing.cast( - StreamEventContextProtocol, - parse_sse_obj( - sse=_sse, - type_=StreamEventContextProtocol, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "completion": + try: + yield typing.cast( + CompletionEvent, + parse_obj_as( + type_=CompletionEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'completion': {e}, sse: {_sse!r}") + elif _sse.event == "error": + try: + yield typing.cast( + ErrorEvent, + parse_obj_as( + type_=ErrorEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'error': {e}, sse: {_sse!r}") + elif _sse.event == "event": + try: + yield typing.cast( + EventEvent, + parse_obj_as( + type_=EventEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'event': {e}, sse: {_sse!r}") return return AsyncHttpResponse(response=_response, data=_iter()) diff --git a/seed/python-sdk/server-sent-event-examples/src/seed/core/pydantic_utilities.py b/seed/python-sdk/server-sent-event-examples/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/server-sent-event-examples/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/server-sent-event-examples/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/server-sent-events-openapi/with-wire-tests/poetry.lock b/seed/python-sdk/server-sent-events-openapi/with-wire-tests/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/server-sent-events-openapi/with-wire-tests/poetry.lock +++ b/seed/python-sdk/server-sent-events-openapi/with-wire-tests/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/raw_client.py b/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/raw_client.py index 0ebe88adf7b8..e1f1e4274537 100644 --- a/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/raw_client.py +++ b/seed/python-sdk/server-sent-events-openapi/with-wire-tests/src/seed/raw_client.py @@ -16,7 +16,14 @@ from .core.serialization import convert_and_respect_annotation_metadata from .types.completion_full_response import CompletionFullResponse from .types.completion_stream_chunk import CompletionStreamChunk +from .types.data_context_entity_event import DataContextEntityEvent +from .types.data_context_heartbeat import DataContextHeartbeat from .types.event import Event +from .types.protocol_collision_object_event import ProtocolCollisionObjectEvent +from .types.protocol_heartbeat import ProtocolHeartbeat +from .types.protocol_number_event import ProtocolNumberEvent +from .types.protocol_object_event import ProtocolObjectEvent +from .types.protocol_string_event import ProtocolStringEvent from .types.stream_data_context_response import StreamDataContextResponse from .types.stream_data_context_with_envelope_schema_response import StreamDataContextWithEnvelopeSchemaResponse from .types.stream_no_context_response import StreamNoContextResponse @@ -74,24 +81,50 @@ def _iter(): for _sse in _event_source.iter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamProtocolNoCollisionResponse, - parse_sse_obj( - sse=_sse, - type_=StreamProtocolNoCollisionResponse, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "heartbeat": + try: + yield typing.cast( + ProtocolHeartbeat, + parse_obj_as( + type_=ProtocolHeartbeat, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'heartbeat': {e}, sse: {_sse!r}") + elif _sse.event == "string_data": + try: + yield typing.cast( + ProtocolStringEvent, + parse_obj_as( + type_=ProtocolStringEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'string_data': {e}, sse: {_sse!r}") + elif _sse.event == "number_data": + try: + yield typing.cast( + ProtocolNumberEvent, + parse_obj_as( + type_=ProtocolNumberEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'number_data': {e}, sse: {_sse!r}") + elif _sse.event == "object_data": + try: + yield typing.cast( + ProtocolObjectEvent, + parse_obj_as( + type_=ProtocolObjectEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'object_data': {e}, sse: {_sse!r}") return return HttpResponse(response=_response, data=_iter()) @@ -150,24 +183,50 @@ def _iter(): for _sse in _event_source.iter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamProtocolCollisionResponse, - parse_sse_obj( - sse=_sse, - type_=StreamProtocolCollisionResponse, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "heartbeat": + try: + yield typing.cast( + ProtocolHeartbeat, + parse_obj_as( + type_=ProtocolHeartbeat, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'heartbeat': {e}, sse: {_sse!r}") + elif _sse.event == "string_data": + try: + yield typing.cast( + ProtocolStringEvent, + parse_obj_as( + type_=ProtocolStringEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'string_data': {e}, sse: {_sse!r}") + elif _sse.event == "number_data": + try: + yield typing.cast( + ProtocolNumberEvent, + parse_obj_as( + type_=ProtocolNumberEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'number_data': {e}, sse: {_sse!r}") + elif _sse.event == "object_data": + try: + yield typing.cast( + ProtocolCollisionObjectEvent, + parse_obj_as( + type_=ProtocolCollisionObjectEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'object_data': {e}, sse: {_sse!r}") return return HttpResponse(response=_response, data=_iter()) @@ -378,24 +437,28 @@ def _iter(): for _sse in _event_source.iter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamProtocolWithFlatSchemaResponse, - parse_sse_obj( - sse=_sse, - type_=StreamProtocolWithFlatSchemaResponse, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "heartbeat": + try: + yield typing.cast( + DataContextHeartbeat, + parse_obj_as( + type_=DataContextHeartbeat, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'heartbeat': {e}, sse: {_sse!r}") + elif _sse.event == "entity": + try: + yield typing.cast( + DataContextEntityEvent, + parse_obj_as( + type_=DataContextEntityEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'entity': {e}, sse: {_sse!r}") return return HttpResponse(response=_response, data=_iter()) @@ -1304,24 +1367,50 @@ async def _iter(): async for _sse in _event_source.aiter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamProtocolNoCollisionResponse, - parse_sse_obj( - sse=_sse, - type_=StreamProtocolNoCollisionResponse, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "heartbeat": + try: + yield typing.cast( + ProtocolHeartbeat, + parse_obj_as( + type_=ProtocolHeartbeat, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'heartbeat': {e}, sse: {_sse!r}") + elif _sse.event == "string_data": + try: + yield typing.cast( + ProtocolStringEvent, + parse_obj_as( + type_=ProtocolStringEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'string_data': {e}, sse: {_sse!r}") + elif _sse.event == "number_data": + try: + yield typing.cast( + ProtocolNumberEvent, + parse_obj_as( + type_=ProtocolNumberEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'number_data': {e}, sse: {_sse!r}") + elif _sse.event == "object_data": + try: + yield typing.cast( + ProtocolObjectEvent, + parse_obj_as( + type_=ProtocolObjectEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'object_data': {e}, sse: {_sse!r}") return return AsyncHttpResponse(response=_response, data=_iter()) @@ -1380,24 +1469,50 @@ async def _iter(): async for _sse in _event_source.aiter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamProtocolCollisionResponse, - parse_sse_obj( - sse=_sse, - type_=StreamProtocolCollisionResponse, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "heartbeat": + try: + yield typing.cast( + ProtocolHeartbeat, + parse_obj_as( + type_=ProtocolHeartbeat, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'heartbeat': {e}, sse: {_sse!r}") + elif _sse.event == "string_data": + try: + yield typing.cast( + ProtocolStringEvent, + parse_obj_as( + type_=ProtocolStringEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'string_data': {e}, sse: {_sse!r}") + elif _sse.event == "number_data": + try: + yield typing.cast( + ProtocolNumberEvent, + parse_obj_as( + type_=ProtocolNumberEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'number_data': {e}, sse: {_sse!r}") + elif _sse.event == "object_data": + try: + yield typing.cast( + ProtocolCollisionObjectEvent, + parse_obj_as( + type_=ProtocolCollisionObjectEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'object_data': {e}, sse: {_sse!r}") return return AsyncHttpResponse(response=_response, data=_iter()) @@ -1608,24 +1723,28 @@ async def _iter(): async for _sse in _event_source.aiter_sse(): if _sse.data == None: return - try: - yield typing.cast( - StreamProtocolWithFlatSchemaResponse, - parse_sse_obj( - sse=_sse, - type_=StreamProtocolWithFlatSchemaResponse, # type: ignore - ), - ) - except JSONDecodeError as e: - warning(f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}") - except (TypeError, ValueError, KeyError, AttributeError) as e: - warning( - f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" - ) - except Exception as e: - error( - f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" - ) + if _sse.event == "heartbeat": + try: + yield typing.cast( + DataContextHeartbeat, + parse_obj_as( + type_=DataContextHeartbeat, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'heartbeat': {e}, sse: {_sse!r}") + elif _sse.event == "entity": + try: + yield typing.cast( + DataContextEntityEvent, + parse_obj_as( + type_=DataContextEntityEvent, # type: ignore + object_=json.loads(_sse.data), + ), + ) + except Exception as e: + warning(f"Failed to parse SSE event 'entity': {e}, sse: {_sse!r}") return return AsyncHttpResponse(response=_response, data=_iter()) diff --git a/seed/python-sdk/server-sent-events-resumable/poetry.lock b/seed/python-sdk/server-sent-events-resumable/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/server-sent-events-resumable/poetry.lock +++ b/seed/python-sdk/server-sent-events-resumable/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/server-sent-events-resumable/src/seed/core/pydantic_utilities.py b/seed/python-sdk/server-sent-events-resumable/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/server-sent-events-resumable/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/server-sent-events-resumable/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/server-sent-events/with-wire-tests/poetry.lock b/seed/python-sdk/server-sent-events/with-wire-tests/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/server-sent-events/with-wire-tests/poetry.lock +++ b/seed/python-sdk/server-sent-events/with-wire-tests/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/server-sent-events/with-wire-tests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/server-sent-events/with-wire-tests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/server-sent-events/with-wire-tests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/server-sent-events/with-wire-tests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/server-url-templating/no-custom-config/poetry.lock b/seed/python-sdk/server-url-templating/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/server-url-templating/no-custom-config/poetry.lock +++ b/seed/python-sdk/server-url-templating/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/server-url-templating/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/server-url-templating/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/server-url-templating/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/server-url-templating/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/simple-api/poetry.lock b/seed/python-sdk/simple-api/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/simple-api/poetry.lock +++ b/seed/python-sdk/simple-api/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/simple-api/src/seed/core/pydantic_utilities.py b/seed/python-sdk/simple-api/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/simple-api/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/simple-api/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/poetry.lock b/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/poetry.lock +++ b/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py b/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/simple-fhir/no-inheritance-for-extended-models/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/single-url-environment-default/poetry.lock b/seed/python-sdk/single-url-environment-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/single-url-environment-default/poetry.lock +++ b/seed/python-sdk/single-url-environment-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/single-url-environment-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/single-url-environment-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/single-url-environment-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/single-url-environment-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/single-url-environment-no-default/poetry.lock b/seed/python-sdk/single-url-environment-no-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/single-url-environment-no-default/poetry.lock +++ b/seed/python-sdk/single-url-environment-no-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/single-url-environment-no-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/single-url-environment-no-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/single-url-environment-no-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/single-url-environment-no-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/streaming-parameter/poetry.lock b/seed/python-sdk/streaming-parameter/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/streaming-parameter/poetry.lock +++ b/seed/python-sdk/streaming-parameter/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/streaming-parameter/src/seed/core/pydantic_utilities.py b/seed/python-sdk/streaming-parameter/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/streaming-parameter/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/streaming-parameter/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/streaming/no-custom-config/poetry.lock b/seed/python-sdk/streaming/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/streaming/no-custom-config/poetry.lock +++ b/seed/python-sdk/streaming/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/streaming/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/streaming/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/streaming/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/streaming/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/streaming/skip-pydantic-validation/poetry.lock b/seed/python-sdk/streaming/skip-pydantic-validation/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/streaming/skip-pydantic-validation/poetry.lock +++ b/seed/python-sdk/streaming/skip-pydantic-validation/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/streaming/skip-pydantic-validation/src/seed/core/pydantic_utilities.py b/seed/python-sdk/streaming/skip-pydantic-validation/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/streaming/skip-pydantic-validation/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/streaming/skip-pydantic-validation/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/trace/poetry.lock b/seed/python-sdk/trace/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/trace/poetry.lock +++ b/seed/python-sdk/trace/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/trace/src/seed/core/pydantic_utilities.py b/seed/python-sdk/trace/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/trace/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/trace/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/undiscriminated-union-with-response-property/poetry.lock b/seed/python-sdk/undiscriminated-union-with-response-property/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/undiscriminated-union-with-response-property/poetry.lock +++ b/seed/python-sdk/undiscriminated-union-with-response-property/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/undiscriminated-union-with-response-property/src/seed/core/pydantic_utilities.py b/seed/python-sdk/undiscriminated-union-with-response-property/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/undiscriminated-union-with-response-property/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/undiscriminated-union-with-response-property/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/undiscriminated-unions/poetry.lock b/seed/python-sdk/undiscriminated-unions/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/undiscriminated-unions/poetry.lock +++ b/seed/python-sdk/undiscriminated-unions/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/undiscriminated-unions/src/seed/core/pydantic_utilities.py b/seed/python-sdk/undiscriminated-unions/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/undiscriminated-unions/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/undiscriminated-unions/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/union-query-parameters/poetry.lock b/seed/python-sdk/union-query-parameters/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/union-query-parameters/poetry.lock +++ b/seed/python-sdk/union-query-parameters/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/union-query-parameters/src/seed/core/pydantic_utilities.py b/seed/python-sdk/union-query-parameters/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/union-query-parameters/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/union-query-parameters/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unions-with-local-date/poetry.lock b/seed/python-sdk/unions-with-local-date/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/unions-with-local-date/poetry.lock +++ b/seed/python-sdk/unions-with-local-date/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unions-with-local-date/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unions-with-local-date/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unions-with-local-date/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unions-with-local-date/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unions/flatten-union-request-bodies/poetry.lock b/seed/python-sdk/unions/flatten-union-request-bodies/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/unions/flatten-union-request-bodies/poetry.lock +++ b/seed/python-sdk/unions/flatten-union-request-bodies/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unions/flatten-union-request-bodies/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unions/flatten-union-request-bodies/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unions/flatten-union-request-bodies/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unions/flatten-union-request-bodies/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unions/no-custom-config/poetry.lock b/seed/python-sdk/unions/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/unions/no-custom-config/poetry.lock +++ b/seed/python-sdk/unions/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unions/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unions/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unions/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unions/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unions/union-naming-v1-wire-tests/poetry.lock b/seed/python-sdk/unions/union-naming-v1-wire-tests/poetry.lock index fb3a0de06c97..45d91dca8189 100644 --- a/seed/python-sdk/unions/union-naming-v1-wire-tests/poetry.lock +++ b/seed/python-sdk/unions/union-naming-v1-wire-tests/poetry.lock @@ -1269,23 +1269,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unions/union-naming-v1-wire-tests/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unions/union-naming-v1-wire-tests/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unions/union-naming-v1-wire-tests/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unions/union-naming-v1-wire-tests/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unions/union-naming-v1/poetry.lock b/seed/python-sdk/unions/union-naming-v1/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/unions/union-naming-v1/poetry.lock +++ b/seed/python-sdk/unions/union-naming-v1/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unions/union-naming-v1/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unions/union-naming-v1/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unions/union-naming-v1/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unions/union-naming-v1/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unions/union-utils/poetry.lock b/seed/python-sdk/unions/union-utils/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/unions/union-utils/poetry.lock +++ b/seed/python-sdk/unions/union-utils/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unions/union-utils/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unions/union-utils/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unions/union-utils/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unions/union-utils/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/unknown/poetry.lock b/seed/python-sdk/unknown/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/unknown/poetry.lock +++ b/seed/python-sdk/unknown/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/unknown/src/seed/core/pydantic_utilities.py b/seed/python-sdk/unknown/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/unknown/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/unknown/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/url-form-encoded/poetry.lock b/seed/python-sdk/url-form-encoded/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/url-form-encoded/poetry.lock +++ b/seed/python-sdk/url-form-encoded/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/url-form-encoded/src/seed/core/pydantic_utilities.py b/seed/python-sdk/url-form-encoded/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/url-form-encoded/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/url-form-encoded/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/validation/no-custom-config/poetry.lock b/seed/python-sdk/validation/no-custom-config/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/validation/no-custom-config/poetry.lock +++ b/seed/python-sdk/validation/no-custom-config/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/validation/no-custom-config/src/seed/core/pydantic_utilities.py b/seed/python-sdk/validation/no-custom-config/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/validation/no-custom-config/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/validation/no-custom-config/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/validation/with-defaults-parameters/poetry.lock b/seed/python-sdk/validation/with-defaults-parameters/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/validation/with-defaults-parameters/poetry.lock +++ b/seed/python-sdk/validation/with-defaults-parameters/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/validation/with-defaults-parameters/src/seed/core/pydantic_utilities.py b/seed/python-sdk/validation/with-defaults-parameters/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/validation/with-defaults-parameters/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/validation/with-defaults-parameters/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/validation/with-defaults/poetry.lock b/seed/python-sdk/validation/with-defaults/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/validation/with-defaults/poetry.lock +++ b/seed/python-sdk/validation/with-defaults/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/validation/with-defaults/src/seed/core/pydantic_utilities.py b/seed/python-sdk/validation/with-defaults/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/validation/with-defaults/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/validation/with-defaults/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/variables/poetry.lock b/seed/python-sdk/variables/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/variables/poetry.lock +++ b/seed/python-sdk/variables/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/variables/src/seed/core/pydantic_utilities.py b/seed/python-sdk/variables/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/variables/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/variables/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/version-no-default/poetry.lock b/seed/python-sdk/version-no-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/version-no-default/poetry.lock +++ b/seed/python-sdk/version-no-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/version-no-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/version-no-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/version-no-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/version-no-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/version/poetry.lock b/seed/python-sdk/version/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/version/poetry.lock +++ b/seed/python-sdk/version/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/version/src/seed/core/pydantic_utilities.py b/seed/python-sdk/version/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/version/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/version/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/webhook-audience/poetry.lock b/seed/python-sdk/webhook-audience/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/webhook-audience/poetry.lock +++ b/seed/python-sdk/webhook-audience/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/webhook-audience/src/seed/core/pydantic_utilities.py b/seed/python-sdk/webhook-audience/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/webhook-audience/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/webhook-audience/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/webhooks/poetry.lock b/seed/python-sdk/webhooks/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/webhooks/poetry.lock +++ b/seed/python-sdk/webhooks/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/webhooks/src/seed/core/pydantic_utilities.py b/seed/python-sdk/webhooks/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/webhooks/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/webhooks/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/websocket-bearer-auth/poetry.lock b/seed/python-sdk/websocket-bearer-auth/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/websocket-bearer-auth/poetry.lock +++ b/seed/python-sdk/websocket-bearer-auth/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/websocket-bearer-auth/src/seed/core/pydantic_utilities.py b/seed/python-sdk/websocket-bearer-auth/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/websocket-bearer-auth/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/websocket-bearer-auth/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/websocket-inferred-auth/poetry.lock b/seed/python-sdk/websocket-inferred-auth/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/websocket-inferred-auth/poetry.lock +++ b/seed/python-sdk/websocket-inferred-auth/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/websocket-inferred-auth/src/seed/core/pydantic_utilities.py b/seed/python-sdk/websocket-inferred-auth/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/websocket-inferred-auth/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/websocket-inferred-auth/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/websocket-multi-url/poetry.lock b/seed/python-sdk/websocket-multi-url/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/websocket-multi-url/poetry.lock +++ b/seed/python-sdk/websocket-multi-url/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/websocket-multi-url/src/seed/core/pydantic_utilities.py b/seed/python-sdk/websocket-multi-url/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/websocket-multi-url/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/websocket-multi-url/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/websocket/websocket-base/poetry.lock b/seed/python-sdk/websocket/websocket-base/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/websocket/websocket-base/poetry.lock +++ b/seed/python-sdk/websocket/websocket-base/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/websocket/websocket-base/src/seed/core/pydantic_utilities.py b/seed/python-sdk/websocket/websocket-base/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/websocket/websocket-base/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/websocket/websocket-base/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/poetry.lock b/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/poetry.lock index 1e0a847d3788..7be1340151fd 100644 --- a/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/poetry.lock +++ b/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/src/seed/core/pydantic_utilities.py b/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/websocket/websocket-with_generated_clients-skip_validation/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/websocket/websocket-with_generated_clients/poetry.lock b/seed/python-sdk/websocket/websocket-with_generated_clients/poetry.lock index 1e0a847d3788..7be1340151fd 100644 --- a/seed/python-sdk/websocket/websocket-with_generated_clients/poetry.lock +++ b/seed/python-sdk/websocket/websocket-with_generated_clients/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/websocket/websocket-with_generated_clients/src/seed/core/pydantic_utilities.py b/seed/python-sdk/websocket/websocket-with_generated_clients/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/websocket/websocket-with_generated_clients/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/websocket/websocket-with_generated_clients/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {} diff --git a/seed/python-sdk/x-fern-default/poetry.lock b/seed/python-sdk/x-fern-default/poetry.lock index 1d4c695a34d6..4d94723fe480 100644 --- a/seed/python-sdk/x-fern-default/poetry.lock +++ b/seed/python-sdk/x-fern-default/poetry.lock @@ -1130,23 +1130,23 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests [[package]] name = "pytest-asyncio" -version = "1.3.0" +version = "1.4.0" description = "Pytest support for asyncio" optional = false python-versions = ">=3.10" groups = ["dev"] files = [ - {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, - {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, + {file = "pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1"}, + {file = "pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42"}, ] [package.dependencies] backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""} -pytest = ">=8.2,<10" +pytest = ">=8.4,<10" typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} [package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)", "sphinx-tabs (>=3.5)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] [[package]] diff --git a/seed/python-sdk/x-fern-default/src/seed/core/pydantic_utilities.py b/seed/python-sdk/x-fern-default/src/seed/core/pydantic_utilities.py index df3e720da4f8..6587f5e1820f 100644 --- a/seed/python-sdk/x-fern-default/src/seed/core/pydantic_utilities.py +++ b/seed/python-sdk/x-fern-default/src/seed/core/pydantic_utilities.py @@ -135,111 +135,21 @@ def _decimal_encoder(dec_value: Any) -> Any: Model = TypeVar("Model", bound=pydantic.BaseModel) -def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]: - """ - Extract the discriminator field name and union variants from a discriminated union type. - Supports Annotated[Union[...], Field(discriminator=...)] patterns. - Returns (discriminator, variants) or (None, None) if not a discriminated union. - """ - origin = typing_extensions.get_origin(type_) - - if origin is typing_extensions.Annotated: - args = typing_extensions.get_args(type_) - if len(args) >= 2: - inner_type = args[0] - # Check annotations for discriminator - discriminator = None - for annotation in args[1:]: - if hasattr(annotation, "discriminator"): - discriminator = getattr(annotation, "discriminator", None) - break - - if discriminator: - inner_origin = typing_extensions.get_origin(inner_type) - if inner_origin is Union: - variants = list(typing_extensions.get_args(inner_type)) - return discriminator, variants - return None, None - - -def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]: - """Get the type annotation of a field from a Pydantic model.""" - if IS_PYDANTIC_V2: - fields = getattr(model, "model_fields", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.annotation) - else: - fields = getattr(model, "__fields__", {}) - field_info = fields.get(field_name) - if field_info: - return cast(Optional[Type[Any]], field_info.outer_type_) - return None - - -def _find_variant_by_discriminator( - variants: List[Type[Any]], - discriminator: str, - discriminator_value: Any, -) -> Optional[Type[Any]]: - """Find the union variant that matches the discriminator value.""" - for variant in variants: - if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)): - continue - - disc_annotation = _get_field_annotation(variant, discriminator) - if disc_annotation and is_literal_type(disc_annotation): - literal_args = get_args(disc_annotation) - if literal_args and literal_args[0] == discriminator_value: - return variant - return None - - -def _is_string_type(type_: Type[Any]) -> bool: - """Check if a type is str or Optional[str].""" - if type_ is str: - return True - - origin = typing_extensions.get_origin(type_) - if origin is Union: - args = typing_extensions.get_args(type_) - # Optional[str] = Union[str, None] - non_none_args = [a for a in args if a is not type(None)] - if len(non_none_args) == 1 and non_none_args[0] is str: - return True - - return False - - def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: """ Parse a ServerSentEvent into the appropriate type. - Handles two scenarios based on where the discriminator field is located: + This function handles data-level discrimination where the discriminator + (e.g., 'type') is inside the 'data' payload. It parses the SSE data field + as JSON and deserializes it into the target type. - 1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload. - The union describes the data content, not the SSE envelope. - -> Returns: json.loads(data) parsed into the type - - Example: ChatStreamResponse with discriminator='type' - Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="") - Output: ContentDeltaEvent (parsed from data, SSE envelope stripped) - - 2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level. - The union describes the full SSE event structure. - -> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string - - Example: JobStreamResponse with discriminator='event' - Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123") - Output: JobStreamResponse_Error with data as ErrorData object - - But for variants where data is str (like STATUS_UPDATE): - Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1") - Output: JobStreamResponse_StatusUpdate with data as string (not parsed) + Note: Protocol-level discrimination (where the discriminator comes from + the SSE event: field) is handled at code-generation time and does not + use this function. Args: sse: The ServerSentEvent object to parse - type_: The target discriminated union type + type_: The target type to deserialize into Returns: The parsed object of type T @@ -248,66 +158,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T: This function is only available in SDK contexts where http_sse module exists. """ sse_event = asdict(sse) - discriminator, variants = _get_discriminator_and_variants(type_) - - if discriminator is None or variants is None: - # Not a discriminated union - parse the data field as JSON - data_value = sse_event.get("data") - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) - data_value = sse_event.get("data") - - # Check if discriminator is at the top level (event-level discrimination) - if discriminator in sse_event: - # Case 2: Event-level discrimination - # Find the matching variant to check if 'data' field needs JSON parsing - disc_value = sse_event.get(discriminator) - matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value) - - if matching_variant is not None: - # Check what type the variant expects for 'data' - data_type = _get_field_annotation(matching_variant, "data") - if data_type is not None and not _is_string_type(data_type): - # Variant expects non-string data - parse JSON - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - new_object = dict(sse_event) - new_object["data"] = parsed_data - return parse_obj_as(type_, new_object) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - # Either no matching variant, data is string type, or JSON parse failed - return parse_obj_as(type_, sse_event) - - else: - # Case 1: Data-level discrimination - # The discriminator is inside the data payload - extract and parse data only - if isinstance(data_value, str) and data_value: - try: - parsed_data = json.loads(data_value) - return parse_obj_as(type_, parsed_data) - except json.JSONDecodeError as e: - _logger.warning( - "Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s", - e, - data_value[:100] if len(data_value) > 100 else data_value, - ) - return parse_obj_as(type_, sse_event) + if isinstance(data_value, str) and data_value: + try: + parsed_data = json.loads(data_value) + return parse_obj_as(type_, parsed_data) + except json.JSONDecodeError as e: + _logger.warning( + "Failed to parse SSE data field as JSON: %s, data: %s", + e, + data_value[:100] if len(data_value) > 100 else data_value, + ) + return parse_obj_as(type_, sse_event) _type_adapter_cache: Dict[int, Any] = {}