44import json
55import logging
66from copy import deepcopy
7- from typing import TYPE_CHECKING , Any , Callable , Mapping , MutableMapping , Sequence
7+ from typing import TYPE_CHECKING , Any , Callable , Mapping , MutableMapping , Sequence , get_origin
88from urllib .parse import parse_qs
99
1010from pydantic import BaseModel
1515 _normalize_errors ,
1616 _regenerate_error_with_loc ,
1717 get_missing_field_error ,
18+ lenient_issubclass ,
1819)
1920from aws_lambda_powertools .event_handler .openapi .dependant import is_scalar_field
2021from aws_lambda_powertools .event_handler .openapi .encoders import jsonable_encoder
2122from aws_lambda_powertools .event_handler .openapi .exceptions import RequestValidationError , ResponseValidationError
22- from aws_lambda_powertools .event_handler .openapi .params import Header , Param , Query
23+ from aws_lambda_powertools .event_handler .openapi .params import Param
2324
2425if TYPE_CHECKING :
2526 from aws_lambda_powertools .event_handler import Response
@@ -64,7 +65,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6465 )
6566
6667 # Normalize query values before validate this
67- query_string = _normalize_multi_query_string_with_param (
68+ query_string = _normalize_multi_params (
6869 app .current_event .resolved_query_string_parameters ,
6970 route .dependant .query_params ,
7071 )
@@ -76,7 +77,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
7677 )
7778
7879 # Normalize header values before validate this
79- headers = _normalize_multi_header_values_with_param (
80+ headers = _normalize_multi_params (
8081 app .current_event .resolved_headers_field ,
8182 route .dependant .header_params ,
8283 )
@@ -316,38 +317,33 @@ def _request_params_to_args(
316317 received_params : Mapping [str , Any ],
317318) -> tuple [dict [str , Any ], list [Any ]]:
318319 """
319- Convert request params to a dictionary of values with Pydantic model support .
320+ Convert the request params to a dictionary of values using validation, and returns a list of errors .
320321 """
321322 values = {}
322323 errors = []
323324
324325 for field in required_params :
325326 field_info = field .field_info
326327
327- # Check if this is a Pydantic model in Query/Header
328- from pydantic import BaseModel
329-
330- from aws_lambda_powertools .event_handler .openapi .compat import lenient_issubclass
331-
332- if isinstance (field_info , (Query , Header )) and lenient_issubclass (field_info .annotation , BaseModel ):
333- pass
334- elif isinstance (field_info , Param ):
335- pass
336- else :
328+ # To ensure early failure, we check if it's not an instance of Param.
329+ if not isinstance (field_info , Param ):
337330 raise AssertionError (f"Expected Param field_info, got { field_info } " )
338331
339332 value = received_params .get (field .alias )
333+
340334 loc = (field_info .in_ .value , field .alias )
341335
336+ # If we don't have a value, see if it's required or has a default
342337 if value is None :
343338 if field .required :
344339 errors .append (get_missing_field_error (loc = loc ))
345340 else :
346341 values [field .name ] = deepcopy (field .default )
347342 continue
348343
349- # Use _validate_field like _request_body_to_args does
344+ # Finally, validate the value
350345 values [field .name ] = _validate_field (field = field , value = value , loc = loc , existing_errors = errors )
346+
351347 return values , errors
352348
353349
@@ -439,116 +435,53 @@ def _get_embed_body(
439435 return received_body , field_alias_omitted
440436
441437
442- def _normalize_multi_query_string_with_param (
443- query_string : dict [str , list [ str ] ],
438+ def _normalize_multi_params (
439+ input_dict : MutableMapping [str , Any ],
444440 params : Sequence [ModelField ],
445- ) -> dict [str , Any ]:
441+ ) -> MutableMapping [str , Any ]:
446442 """
447- Extract and normalize resolved_query_string_parameters with Pydantic model support
443+ Extract and normalize query string or header parameters with Pydantic model support.
448444
449445 Parameters
450446 ----------
451- query_string: dict
452- A dictionary containing the initial query string parameters.
447+ input_dict: MutableMapping[str, Any]
448+ A dictionary containing the initial query string or header parameters.
453449 params: Sequence[ModelField]
454450 A sequence of ModelField objects representing parameters.
455451
456452 Returns
457453 -------
458- A dictionary containing the processed multi_query_string_parameters.
454+ MutableMapping[str, Any]
455+ A dictionary containing the processed parameters with normalized values.
459456 """
460- resolved_query_string : dict [str , Any ] = query_string
461-
462457 for param in params :
463- # Handle scalar fields (existing logic)
464458 if is_scalar_field (param ):
465459 try :
466- resolved_query_string [param .alias ] = query_string [param .alias ][0 ]
460+ val = input_dict [param .alias ]
461+ if isinstance (val , list ) and len (val ) == 1 :
462+ input_dict [param .alias ] = val [0 ]
463+ elif isinstance (val , list ):
464+ pass # leave as list for multi-value
465+ # If it's a string, leave as is
467466 except KeyError :
468467 pass
469- # Handle Pydantic models
470- elif isinstance (param .field_info , Query ) and hasattr (param .field_info , "annotation" ):
471- from pydantic import BaseModel
472-
473- from aws_lambda_powertools .event_handler .openapi .compat import lenient_issubclass
474-
475- if lenient_issubclass (param .field_info .annotation , BaseModel ):
476- model_class = param .field_info .annotation
477- model_data = {}
478-
479- # Collect all fields for the Pydantic model
480- for field_name , field_def in model_class .model_fields .items ():
481- field_alias = field_def .alias or field_name
482- try :
483- model_data [field_alias ] = query_string [field_alias ][0 ]
484- except KeyError :
485- if model_class .model_config .get ("validate_by_name" ) or model_class .model_config .get (
486- "populate_by_name" ,
487- ):
488- try :
489- model_data [field_alias ] = query_string [field_name ][0 ]
490- except KeyError :
491- pass
492-
493- # Store the collected data under the param alias
494- resolved_query_string [param .alias ] = model_data
495-
496- return resolved_query_string
497-
498-
499- def _normalize_multi_header_values_with_param (headers : MutableMapping [str , Any ], params : Sequence [ModelField ]):
500- """
501- Extract and normalize resolved_headers_field with Pydantic model support
502-
503- Parameters
504- ----------
505- headers: MutableMapping[str, Any]
506- A dictionary containing the initial header parameters.
507- params: Sequence[ModelField]
508- A sequence of ModelField objects representing parameters.
509-
510- Returns
511- -------
512- A dictionary containing the processed headers.
513- """
514- if headers :
515- for param in params :
516- # Handle scalar fields (existing logic)
517- if is_scalar_field (param ):
518- try :
519- if len (headers [param .alias ]) == 1 :
520- headers [param .alias ] = headers [param .alias ][0 ]
521- except KeyError :
522- pass
523- # Handle Pydantic models
524- elif isinstance (param .field_info , Header ) and hasattr (param .field_info , "annotation" ):
525- from pydantic import BaseModel
526-
527- from aws_lambda_powertools .event_handler .openapi .compat import lenient_issubclass
528-
529- if lenient_issubclass (param .field_info .annotation , BaseModel ):
530- model_class = param .field_info .annotation
531- model_data = {}
532-
533- # Collect all fields for the Pydantic model
534- for field_name , field_def in model_class .model_fields .items ():
535- field_alias = field_def .alias or field_name
536-
537- # Convert snake_case to kebab-case for headers (HTTP convention)
538- header_key = field_alias .replace ("_" , "-" )
539-
540- try :
541- header_value = headers [header_key ]
542- if isinstance (header_value , list ):
543- if len (header_value ) == 1 :
544- model_data [field_alias ] = header_value [0 ]
545- else :
546- model_data [field_alias ] = header_value
547- else :
548- model_data [field_alias ] = header_value
549- except KeyError :
550- pass
551-
552- # Store the collected data under the param alias
553- headers [param .alias ] = model_data
554- return headers
468+ elif lenient_issubclass (param .field_info .annotation , BaseModel ):
469+ model_class = param .field_info .annotation
470+ model_data = {}
471+
472+ for field_name , field_def in model_class .model_fields .items ():
473+ field_alias = field_def .alias or field_name
474+ value = input_dict .get (field_alias )
475+ if value is None and (
476+ model_class .model_config .get ("validate_by_name" ) or model_class .model_config .get ("populate_by_name" )
477+ ):
478+ value = input_dict .get (field_name )
479+ if value is not None :
480+ if get_origin (field_def .annotation ) is list :
481+ model_data [field_alias ] = value
482+ elif isinstance (value , list ):
483+ model_data [field_alias ] = value [0 ]
484+ else :
485+ model_data [field_alias ] = value
486+ input_dict [param .alias ] = model_data
487+ return input_dict
0 commit comments