Skip to content

feat(event_handler): add support for Pydantic models in Query and Header types - WIP #7076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 65 additions & 19 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def _openapi_operation_parameters(
from aws_lambda_powertools.event_handler.openapi.compat import (
get_schema_from_model_field,
)
from aws_lambda_powertools.event_handler.openapi.params import Param
from aws_lambda_powertools.event_handler.openapi.params import Form, Header, Param, Query

parameters = []
parameter: dict[str, Any] = {}
Expand All @@ -826,32 +826,78 @@ def _openapi_operation_parameters(
if not field_info.include_in_schema:
continue

param_schema = get_schema_from_model_field(
field=param,
model_name_map=model_name_map,
field_mapping=field_mapping,
)
# Check if this is a Pydantic model that should be expanded
from pydantic import BaseModel

parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": param_schema,
}
from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass

if isinstance(field_info, (Query, Header, Form)) and lenient_issubclass(field_info.annotation, BaseModel):
# Expand Pydantic model into individual parameters
model_class = field_info.annotation

if field_info.description:
parameter["description"] = field_info.description
for field_name, field_def in model_class.model_fields.items():
# Create individual parameter for each model field
param_name = field_def.alias or field_name

if field_info.openapi_examples:
parameter["examples"] = field_info.openapi_examples
# Convert snake_case to kebab-case for headers (HTTP convention)
if isinstance(field_info, Header):
param_name = param_name.replace("_", "-")

if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated
individual_param = {
"name": param_name,
"in": field_info.in_.value,
"required": field_def.is_required()
if hasattr(field_def, "is_required")
else field_def.default is ...,
"schema": Route._get_basic_type_schema(field_def.annotation),
}

if field_def.description:
individual_param["description"] = field_def.description

parameters.append(individual_param)
else:
# Regular parameter processing
param_schema = get_schema_from_model_field(
field=param,
model_name_map=model_name_map,
field_mapping=field_mapping,
)

parameters.append(parameter)
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": param_schema,
}

if field_info.description:
parameter["description"] = field_info.description

if field_info.openapi_examples:
parameter["examples"] = field_info.openapi_examples

if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated

parameters.append(parameter)

return parameters

@staticmethod
def _get_basic_type_schema(param_type: type) -> dict[str, str]:
"""
Get basic OpenAPI schema for simple types
"""
if isinstance(int, param_type):
return {"type": "integer"}
elif isinstance(float, param_type):
return {"type": "number"}
elif isinstance(bool, param_type):
return {"type": "boolean"}
else:
return {"type": "string"}

@staticmethod
def _openapi_operation_return(
*,
Expand Down
151 changes: 128 additions & 23 deletions aws_lambda_powertools/event_handler/middlewares/openapi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
from urllib.parse import parse_qs

from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler
from aws_lambda_powertools.event_handler.openapi.compat import (
Expand All @@ -19,7 +19,7 @@
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
from aws_lambda_powertools.event_handler.openapi.params import Param
from aws_lambda_powertools.event_handler.openapi.params import Header, Param, Query

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler import Response
Expand Down Expand Up @@ -69,8 +69,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
route.dependant.query_params,
)

# Process query values
query_values, query_errors = _request_params_to_args(
# Process query values (with Pydantic model support)
query_values, query_errors = _request_params_to_args_with_pydantic_support(
route.dependant.query_params,
query_string,
)
Expand All @@ -81,8 +81,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
route.dependant.header_params,
)

# Process header values
header_values, header_errors = _request_params_to_args(
# Process header values (with Pydantic model support)
header_values, header_errors = _request_params_to_args_with_pydantic_support(
route.dependant.header_params,
headers,
)
Expand Down Expand Up @@ -311,6 +311,58 @@ def _prepare_response_content(
return res # pragma: no cover


def _request_params_to_args_with_pydantic_support(
required_params: Sequence[ModelField],
received_params: Mapping[str, Any],
) -> tuple[dict[str, Any], list[Any]]:
"""
Convert request params to a dictionary of values with Pydantic model support.
"""
values = {}
errors = []

for field in required_params:
field_info = field.field_info

# Check if this is a Pydantic model in Query/Header
from pydantic import BaseModel

from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass

if isinstance(field_info, (Query, Header)) and lenient_issubclass(field_info.annotation, BaseModel):
# Handle Pydantic model - use the same approach as _request_body_to_args
loc = (field_info.in_.value, field.alias)

# Get the raw data for the Pydantic model
value = received_params.get(field.alias)

if value is None:
if field.required:
errors.append(get_missing_field_error(loc))
else:
values[field.name] = deepcopy(field.default)
continue

else:
# Regular parameter processing (existing logic)
if not isinstance(field_info, Param):
raise AssertionError(f"Expected Param field_info, got {field_info}")

value = received_params.get(field.alias)
loc = (field_info.in_.value, field.alias)

if value is None:
if field.required:
errors.append(get_missing_field_error(loc=loc))
else:
values[field.name] = deepcopy(field.default)
continue

# Use _validate_field like _request_body_to_args does
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
return values, errors


def _request_params_to_args(
required_params: Sequence[ModelField],
received_params: Mapping[str, Any],
Expand Down Expand Up @@ -439,7 +491,7 @@ def _normalize_multi_query_string_with_param(
params: Sequence[ModelField],
) -> dict[str, Any]:
"""
Extract and normalize resolved_query_string_parameters
Extract and normalize resolved_query_string_parameters with Pydantic model support

Parameters
----------
Expand All @@ -453,19 +505,41 @@ def _normalize_multi_query_string_with_param(
A dictionary containing the processed multi_query_string_parameters.
"""
resolved_query_string: dict[str, Any] = query_string
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
resolved_query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass

for param in params:
# Handle scalar fields (existing logic)
if is_scalar_field(param):
try:
resolved_query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass
# Handle Pydantic models
elif isinstance(param.field_info, Query) and hasattr(param.field_info, "annotation"):
from pydantic import BaseModel

from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass

if lenient_issubclass(param.field_info.annotation, BaseModel):
model_class = param.field_info.annotation
model_data = {}

# Collect all fields for the Pydantic model
for field_name, field_def in model_class.model_fields.items():
field_alias = field_def.alias or field_name
try:
model_data[field_alias] = query_string[field_alias][0]
except KeyError:
pass

# Store the collected data under the param alias
resolved_query_string[param.alias] = model_data

return resolved_query_string


def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
"""
Extract and normalize resolved_headers_field
Extract and normalize resolved_headers_field with Pydantic model support

Parameters
----------
Expand All @@ -479,12 +553,43 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
A dictionary containing the processed headers.
"""
if headers:
for param in filter(is_scalar_field, params):
try:
if len(headers[param.alias]) == 1:
# if the target parameter is a scalar and the list contains only 1 element
# we keep the first value of the headers regardless if there are more in the payload
headers[param.alias] = headers[param.alias][0]
except KeyError:
pass
for param in params:
# Handle scalar fields (existing logic)
if is_scalar_field(param):
try:
if len(headers[param.alias]) == 1:
headers[param.alias] = headers[param.alias][0]
except KeyError:
pass
# Handle Pydantic models
elif isinstance(param.field_info, Header) and hasattr(param.field_info, "annotation"):
from pydantic import BaseModel

from aws_lambda_powertools.event_handler.openapi.compat import lenient_issubclass

if lenient_issubclass(param.field_info.annotation, BaseModel):
model_class = param.field_info.annotation
model_data = {}

# Collect all fields for the Pydantic model
for field_name, field_def in model_class.model_fields.items():
field_alias = field_def.alias or field_name

# Convert snake_case to kebab-case for headers (HTTP convention)
header_key = field_alias.replace("_", "-")

try:
header_value = headers[header_key]
if isinstance(header_value, list):
if len(header_value) == 1:
model_data[field_alias] = header_value[0]
else:
model_data[field_alias] = header_value
else:
model_data[field_alias] = header_value
except KeyError:
pass

# Store the collected data under the param alias
headers[param.alias] = model_data
return headers
Loading
Loading