Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
174 changes: 18 additions & 156 deletions generators/python/core_utilities/shared/pydantic_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,111 +133,21 @@ def _decimal_encoder(dec_value: Any) -> Any:
Model = TypeVar("Model", bound=pydantic.BaseModel)


def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """
    Extract the discriminator field name and union variants from a discriminated union type.

    Supports ``Annotated[Union[...], Field(discriminator=...)]`` patterns.

    Args:
        type_: The type to inspect.

    Returns:
        ``(discriminator, variants)`` when ``type_`` is an ``Annotated`` wrapper
        around a ``Union`` and one of its annotations carries a non-empty
        ``discriminator`` attribute; ``(None, None)`` otherwise.
    """
    if typing_extensions.get_origin(type_) is not typing_extensions.Annotated:
        return None, None

    args = typing_extensions.get_args(type_)
    if len(args) < 2:
        return None, None
    inner_type, *metadata = args

    # Use the first annotation whose discriminator is actually set. Stopping at
    # the first annotation that merely *has* the attribute would let metadata
    # whose ``discriminator`` is None hide a discriminator declared on a later
    # annotation.
    discriminator = next(
        (value for value in (getattr(item, "discriminator", None) for item in metadata) if value),
        None,
    )
    if not discriminator:
        return None, None

    # Only a Union has variants to choose between.
    if typing_extensions.get_origin(inner_type) is not Union:
        return None, None
    return discriminator, list(typing_extensions.get_args(inner_type))


def _get_field_annotation(model: Type[Any], field_name: str) -> Optional[Type[Any]]:
    """
    Get the type annotation of a field from a Pydantic model.

    Under Pydantic v2 the field is looked up in ``model.model_fields`` and its
    ``annotation`` is returned; otherwise it is looked up in ``model.__fields__``
    and its ``outer_type_`` is returned.

    Args:
        model: The model class to inspect.
        field_name: Name of the field to look up.

    Returns:
        The field's annotation, or ``None`` when the model has no such field.
    """
    if IS_PYDANTIC_V2:
        registry_attr, annotation_attr = "model_fields", "annotation"
    else:
        registry_attr, annotation_attr = "__fields__", "outer_type_"

    # A model without the registry attribute is treated as having no fields.
    field_info = getattr(model, registry_attr, {}).get(field_name)
    if not field_info:
        return None
    return cast(Optional[Type[Any]], getattr(field_info, annotation_attr))


def _find_variant_by_discriminator(
    variants: List[Type[Any]],
    discriminator: str,
    discriminator_value: Any,
) -> Optional[Type[Any]]:
    """
    Find the union variant that matches the discriminator value.

    Args:
        variants: Candidate union members; entries that are not
            ``pydantic.BaseModel`` subclasses are ignored.
        discriminator: Name of the discriminator field on each variant.
        discriminator_value: The value to match against each variant's
            ``Literal`` discriminator annotation.

    Returns:
        The first variant whose discriminator ``Literal`` contains
        ``discriminator_value``, or ``None`` when no variant matches.
    """
    for variant in variants:
        if not (inspect.isclass(variant) and issubclass(variant, pydantic.BaseModel)):
            continue

        disc_annotation = _get_field_annotation(variant, discriminator)
        if disc_annotation and is_literal_type(disc_annotation):
            # A variant may declare several tags, e.g. Literal["a", "b"];
            # match against all of them rather than only the first.
            if discriminator_value in get_args(disc_annotation):
                return variant
    return None


def _is_string_type(type_: Type[Any]) -> bool:
    """
    Check if a type is ``str`` or ``Optional[str]``.

    Args:
        type_: The type annotation to classify.

    Returns:
        ``True`` for ``str`` itself, or for a ``Union`` whose only non-``None``
        member is ``str``; ``False`` for everything else.
    """
    if type_ is str:
        return True
    if typing_extensions.get_origin(type_) is not Union:
        return False

    # Optional[str] is Union[str, None]: drop NoneType and inspect what remains.
    members = [member for member in typing_extensions.get_args(type_) if member is not type(None)]
    return len(members) == 1 and members[0] is str


def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
"""
Parse a ServerSentEvent into the appropriate type.

Handles two scenarios based on where the discriminator field is located:
This function handles data-level discrimination where the discriminator
(e.g., 'type') is inside the 'data' payload. It parses the SSE data field
as JSON and deserializes it into the target type.

1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload.
The union describes the data content, not the SSE envelope.
-> Returns: json.loads(data) parsed into the type

Example: ChatStreamResponse with discriminator='type'
Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="")
Output: ContentDeltaEvent (parsed from data, SSE envelope stripped)

2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level.
The union describes the full SSE event structure.
-> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string

Example: JobStreamResponse with discriminator='event'
Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123")
Output: JobStreamResponse_Error with data as ErrorData object

But for variants where data is str (like STATUS_UPDATE):
Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1")
Output: JobStreamResponse_StatusUpdate with data as string (not parsed)
Note: Protocol-level discrimination (where the discriminator comes from
the SSE event: field) is handled at code-generation time and does not
use this function.

Args:
sse: The ServerSentEvent object to parse
type_: The target discriminated union type
type_: The target type to deserialize into

Returns:
The parsed object of type T
Expand All @@ -246,66 +156,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
This function is only available in SDK contexts where http_sse module exists.
"""
sse_event = asdict(sse)
discriminator, variants = _get_discriminator_and_variants(type_)

if discriminator is None or variants is None:
# Not a discriminated union - parse the data field as JSON
data_value = sse_event.get("data")
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)

data_value = sse_event.get("data")

# Check if discriminator is at the top level (event-level discrimination)
if discriminator in sse_event:
# Case 2: Event-level discrimination
# Find the matching variant to check if 'data' field needs JSON parsing
disc_value = sse_event.get(discriminator)
matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)

if matching_variant is not None:
# Check what type the variant expects for 'data'
data_type = _get_field_annotation(matching_variant, "data")
if data_type is not None and not _is_string_type(data_type):
# Variant expects non-string data - parse JSON
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
new_object = dict(sse_event)
new_object["data"] = parsed_data
return parse_obj_as(type_, new_object)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
# Either no matching variant, data is string type, or JSON parse failed
return parse_obj_as(type_, sse_event)

else:
# Case 1: Data-level discrimination
# The discriminator is inside the data payload - extract and parse data only
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)


_type_adapter_cache: Dict[int, Any] = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# nopycln: file
import datetime as dt
import inspect
import json
import logging
import warnings
from collections import defaultdict
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Tuple, Type, TypeVar, Union, cast

import pydantic
import typing_extensions
from .datetime_utils import serialize_datetime

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,93 +50,21 @@ def parse_date(value: Any) -> dt.date:
AnyCallable = Callable[..., Any]


def _get_discriminator_and_variants(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """
    Extract the discriminator field name and union variants from a discriminated union type.

    Supports ``Annotated[Union[...], Field(discriminator=...)]`` patterns.

    Args:
        type_: The type to inspect.

    Returns:
        ``(discriminator, variants)`` when ``type_`` is an ``Annotated`` wrapper
        around a ``Union`` and one of its annotations carries a non-empty
        ``discriminator`` attribute; ``(None, None)`` otherwise.
    """
    if typing_extensions.get_origin(type_) is not typing_extensions.Annotated:
        return None, None

    args = typing_extensions.get_args(type_)
    if len(args) < 2:
        return None, None
    inner_type, *metadata = args

    # Use the first annotation whose discriminator is actually set. Stopping at
    # the first annotation that merely *has* the attribute would let metadata
    # whose ``discriminator`` is None hide a discriminator declared on a later
    # annotation.
    discriminator = next(
        (value for value in (getattr(item, "discriminator", None) for item in metadata) if value),
        None,
    )
    if not discriminator:
        return None, None

    # Only a Union has variants to choose between.
    if typing_extensions.get_origin(inner_type) is not Union:
        return None, None
    return discriminator, list(typing_extensions.get_args(inner_type))


def _get_field_annotation_v1(model: Type[Any], field_name: str) -> Optional[Type[Any]]:
    """
    Get the type annotation of a field from a Pydantic v1 model.

    Args:
        model: The model class to inspect; its ``__fields__`` mapping is read.
        field_name: Name of the field to look up.

    Returns:
        The field's ``outer_type_``, or ``None`` when the model has no such field.
    """
    # A model without ``__fields__`` is treated as having no fields.
    model_field = getattr(model, "__fields__", {}).get(field_name)
    if not model_field:
        return None
    return cast(Optional[Type[Any]], model_field.outer_type_)


def _find_variant_by_discriminator(
    variants: List[Type[Any]],
    discriminator: str,
    discriminator_value: Any,
) -> Optional[Type[Any]]:
    """
    Find the union variant that matches the discriminator value.

    Args:
        variants: Candidate union members; entries that are not
            ``pydantic.v1.BaseModel`` subclasses are ignored.
        discriminator: Name of the discriminator field on each variant.
        discriminator_value: The value to match against each variant's
            ``Literal`` discriminator annotation.

    Returns:
        The first variant whose discriminator ``Literal`` contains
        ``discriminator_value``, or ``None`` when no variant matches.
    """
    for variant in variants:
        if not (inspect.isclass(variant) and issubclass(variant, pydantic.v1.BaseModel)):
            continue

        disc_annotation = _get_field_annotation_v1(variant, discriminator)
        if disc_annotation and is_literal_type(disc_annotation):
            # A variant may declare several tags, e.g. Literal["a", "b"];
            # match against all of them rather than only the first.
            if discriminator_value in get_args(disc_annotation):
                return variant
    return None


def _is_string_type(type_: Type[Any]) -> bool:
    """
    Check if a type is ``str`` or ``Optional[str]``.

    Args:
        type_: The type annotation to classify.

    Returns:
        ``True`` for ``str`` itself, or for a ``Union`` whose only non-``None``
        member is ``str``; ``False`` for everything else.
    """
    if type_ is str:
        return True
    if typing_extensions.get_origin(type_) is not Union:
        return False

    # Optional[str] is Union[str, None]: drop NoneType and inspect what remains.
    members = [member for member in typing_extensions.get_args(type_) if member is not type(None)]
    return len(members) == 1 and members[0] is str


def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
"""
Parse a ServerSentEvent into the appropriate type.

Handles two scenarios based on where the discriminator field is located:
This function handles data-level discrimination where the discriminator
(e.g., 'type') is inside the 'data' payload. It parses the SSE data field
as JSON and deserializes it into the target type.

1. Data-level discrimination: The discriminator (e.g., 'type') is inside the 'data' payload.
The union describes the data content, not the SSE envelope.
-> Returns: json.loads(data) parsed into the type

2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level.
The union describes the full SSE event structure.
-> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string
Note: Protocol-level discrimination (where the discriminator comes from
the SSE event: field) is handled at code-generation time and does not
use this function.

Args:
sse: The ServerSentEvent object to parse
type_: The target discriminated union type
type_: The target type to deserialize into

Returns:
The parsed object of type T
Expand All @@ -147,66 +73,18 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
This function is only available in SDK contexts where http_sse module exists.
"""
sse_event = asdict(sse)
discriminator, variants = _get_discriminator_and_variants(type_)

if discriminator is None or variants is None:
# Not a discriminated union - parse the data field as JSON
data_value = sse_event.get("data")
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)

data_value = sse_event.get("data")

# Check if discriminator is at the top level (event-level discrimination)
if discriminator in sse_event:
# Case 2: Event-level discrimination
# Find the matching variant to check if 'data' field needs JSON parsing
disc_value = sse_event.get(discriminator)
matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)

if matching_variant is not None:
# Check what type the variant expects for 'data'
data_type = _get_field_annotation_v1(matching_variant, "data")
if data_type is not None and not _is_string_type(data_type):
# Variant expects non-string data - parse JSON
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
new_object = dict(sse_event)
new_object["data"] = parsed_data
return parse_obj_as(type_, new_object)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
# Either no matching variant, data is string type, or JSON parse failed
return parse_obj_as(type_, sse_event)

else:
# Case 1: Data-level discrimination
# The discriminator is inside the data payload - extract and parse data only
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)


def parse_obj_as(type_: Type[T], object_: Any) -> T:
Expand Down
Loading
Loading