From f6ca987403a11f5a4c32b19ba3a41cff85b21d45 Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 21 Aug 2025 09:27:07 +0100 Subject: [PATCH 01/12] feat(event-handler): add support for Pydantic Field discriminator in validation (#5953) Enable use of Field(discriminator='...') with tagged unions in event handler validation. This allows developers to use Pydantic's native discriminator syntax instead of requiring Powertools-specific Param annotations. - Handle Field(discriminator) + Body() combination in get_field_info_annotated_type - Preserve discriminator metadata when creating TypeAdapter in ModelField - Add comprehensive tests for discriminator validation and Field features --- .../event_handler/openapi/compat.py | 17 +++- .../event_handler/openapi/params.py | 52 ++++++++-- .../test_openapi_validation_middleware.py | 96 ++++++++++++++++++- 3 files changed, 152 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index d3340f34e4b..d9c975e3396 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -80,9 +80,20 @@ def type_(self) -> Any: return self.field_info.annotation def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info], - ) + + # If the field_info.annotation is already an Annotated type with discriminator metadata, + # use it directly instead of wrapping it again + annotation = self.field_info.annotation + if ( + get_origin(annotation) is Annotated + and hasattr(self.field_info, "discriminator") + and self.field_info.discriminator is not None + ): + self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation) + else: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[annotation, self.field_info], + ) def get_default(self) -> Any: if self.field_info.is_required(): diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8fc8d0becfa..3743ac0eff7 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1046,17 +1046,47 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup type_annotation = annotated_args[0] powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] - if len(powertools_annotations) > 1: + # Special case: handle Field(discriminator) + Body() combination + # This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body() + has_discriminator_with_body = False + powertools_annotation: FieldInfo | None = None + + if len(powertools_annotations) == 2: + field_obj = None + body_obj = None + for ann in powertools_annotations: + if isinstance(ann, Body): + body_obj = ann + elif isinstance(ann, FieldInfo) and hasattr(ann, "discriminator") and ann.discriminator is not None: + field_obj = ann + + if field_obj and body_obj: + # Use Body as the primary annotation + powertools_annotation = body_obj + # Preserve the full annotation including discriminator for proper validation + # This ensures the discriminator is available when creating the TypeAdapter + type_annotation = annotation + has_discriminator_with_body = True + else: + raise AssertionError("Only one FieldInfo can be used per parameter") + elif len(powertools_annotations) > 1: raise AssertionError("Only one FieldInfo can be used per parameter") - - powertools_annotation = next(iter(powertools_annotations), None) + else: + powertools_annotation = next(iter(powertools_annotations), None) if isinstance(powertools_annotation, FieldInfo): - # Copy `field_info` because we mutate `field_info.default` later - field_info = copy_field_info( - field_info=powertools_annotation, - annotation=annotation, - ) + if has_discriminator_with_body: + # For discriminator + Body case, create a new Body instance directly + # This avoids issues with copy_field_info trying to process the Field + field_info = Body() + field_info.annotation = type_annotation + else: + # Copy `field_info` because we mutate `field_info.default` later + # Use the possibly modified type_annotation for copy_field_info + field_info = copy_field_info( + field_info=powertools_annotation, + annotation=type_annotation, + ) if field_info.default not in [Undefined, Required]: raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") @@ -1067,6 +1097,12 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup else: field_info.default = Required + # Preserve the full annotated type if it contains discriminator metadata + # This is crucial for tagged unions to work properly + if hasattr(powertools_annotation, "discriminator") and powertools_annotation.discriminator is not None: + # Store the full annotated type for discriminated unions + type_annotation = annotation + return field_info, type_annotation diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 1fd919b7b71..9c07a7313ad 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Optional, Tuple +from typing import List, Literal, Optional, Tuple import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated from aws_lambda_powertools.event_handler import ( @@ -1983,3 +1983,95 @@ def get_user(user_id: int) -> UserModel: assert response_body["name"] == "User123" assert response_body["age"] == 143 assert response_body["email"] == "user123@example.com" + + +def test_field_discriminator_validation(gw_event): + """Test that Pydantic Field discriminator works with event_handler validation""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class FooAction(BaseModel): + action: Literal["foo"] + foo_data: str + + class BarAction(BaseModel): + action: Literal["bar"] + bar_data: int + + # This should work with Field discriminator (issue #5953) + Action = Annotated[FooAction | BarAction, Field(discriminator="action")] + + @app.post("/actions") + def create_action(action: Annotated[Action, Body()]): + return {"received_action": action.action, "data": action.model_dump()} + + # WHEN sending a valid foo action + gw_event["path"] = "/actions" + gw_event["httpMethod"] = "POST" + gw_event["headers"]["content-type"] = "application/json" + gw_event["body"] = '{"action": "foo", "foo_data": "test"}' + + # THEN the handler should be invoked and return 200 + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["received_action"] == "foo" + assert response_body["data"]["action"] == "foo" + assert response_body["data"]["foo_data"] == "test" + + # WHEN sending a valid bar action + gw_event["body"] = '{"action": "bar", "bar_data": 123}' + + # THEN the handler should be invoked and return 200 + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["received_action"] == "bar" + assert response_body["data"]["action"] == "bar" + assert response_body["data"]["bar_data"] == 123 + + # WHEN sending an invalid discriminator + gw_event["body"] = '{"action": "invalid", "some_data": "test"}' + + # THEN the handler should return 422 (validation error) + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + +def test_field_other_features_still_work(gw_event): + """Test that other Field features still work after discriminator fix""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class UserInput(BaseModel): + name: Annotated[str, Field(min_length=2, max_length=50, description="User name")] + age: Annotated[int, Field(ge=18, le=120, description="User age")] + email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")] + + @app.post("/users") + def create_user(user: UserInput): + return {"created": user.model_dump()} + + # WHEN sending valid data + gw_event["path"] = "/users" + gw_event["httpMethod"] = "POST" + gw_event["headers"]["content-type"] = "application/json" + gw_event["body"] = '{"name": "John", "age": 25, "email": "john@example.com"}' + + # THEN the handler should return 200 + result = app(gw_event, {}) + assert result["statusCode"] == 200 + + response_body = json.loads(result["body"]) + assert response_body["created"]["name"] == "John" + assert response_body["created"]["age"] == 25 + assert response_body["created"]["email"] == "john@example.com" + + # WHEN sending data with validation error (age too low) + gw_event["body"] = '{"name": "John", "age": 16, "email": "john@example.com"}' + + # THEN the handler should return 422 (validation error) + result = app(gw_event, {}) + assert result["statusCode"] == 422 From 63a0225f147c2abd9ad2ded9d0e2fbbedf32cea9 Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 21 Aug 2025 09:43:04 +0100 Subject: [PATCH 02/12] style(tests): remove inline comments to match project test style --- .../_pydantic/test_openapi_validation_middleware.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 9c07a7313ad..dc55cd55772 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1987,7 +1987,6 @@ def get_user(user_id: int) -> UserModel: def test_field_discriminator_validation(gw_event): """Test that Pydantic Field discriminator works with event_handler validation""" - # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class FooAction(BaseModel): @@ -1998,20 +1997,17 @@ class BarAction(BaseModel): action: Literal["bar"] bar_data: int - # This should work with Field discriminator (issue #5953) Action = Annotated[FooAction | BarAction, Field(discriminator="action")] @app.post("/actions") def create_action(action: Annotated[Action, Body()]): return {"received_action": action.action, "data": action.model_dump()} - # WHEN sending a valid foo action gw_event["path"] = "/actions" gw_event["httpMethod"] = "POST" gw_event["headers"]["content-type"] = "application/json" gw_event["body"] = '{"action": "foo", "foo_data": "test"}' - # THEN the handler should be invoked and return 200 result = app(gw_event, {}) assert result["statusCode"] == 200 @@ -2020,10 +2016,8 @@ def create_action(action: Annotated[Action, Body()]): assert response_body["data"]["action"] == "foo" assert response_body["data"]["foo_data"] == "test" - # WHEN sending a valid bar action gw_event["body"] = '{"action": "bar", "bar_data": 123}' - # THEN the handler should be invoked and return 200 result = app(gw_event, {}) assert result["statusCode"] == 200 @@ -2032,17 +2026,14 @@ def create_action(action: Annotated[Action, Body()]): assert response_body["data"]["action"] == "bar" assert response_body["data"]["bar_data"] == 123 - # WHEN sending an invalid discriminator gw_event["body"] = '{"action": "invalid", "some_data": "test"}' - # THEN the handler should return 422 (validation error) result = app(gw_event, {}) assert result["statusCode"] == 422 def test_field_other_features_still_work(gw_event): """Test that other Field features still work after discriminator fix""" - # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) class UserInput(BaseModel): @@ -2054,13 +2045,11 @@ class UserInput(BaseModel): def create_user(user: UserInput): return {"created": user.model_dump()} - # WHEN sending valid data gw_event["path"] = "/users" gw_event["httpMethod"] = "POST" gw_event["headers"]["content-type"] = "application/json" gw_event["body"] = '{"name": "John", "age": 25, "email": "john@example.com"}' - # THEN the handler should return 200 result = app(gw_event, {}) assert result["statusCode"] == 200 @@ -2069,9 +2058,7 @@ def create_user(user: UserInput): assert response_body["created"]["age"] == 25 assert response_body["created"]["email"] == "john@example.com" - # WHEN sending data with validation error (age too low) gw_event["body"] = '{"name": "John", "age": 16, "email": "john@example.com"}' - # THEN the handler should return 422 (validation error) result = app(gw_event, {}) assert result["statusCode"] == 422 From 0023b3ac15a63538abb310373865d264aba52308 Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 21 Aug 2025 10:39:39 +0100 Subject: [PATCH 03/12] style: run make format to fix CI formatting issues --- aws_lambda_powertools/event_handler/openapi/compat.py | 1 - aws_lambda_powertools/event_handler/openapi/params.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index d9c975e3396..74945748921 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -80,7 +80,6 @@ def type_(self) -> Any: return self.field_info.annotation def __post_init__(self) -> None: - # If the field_info.annotation is already an Annotated type with discriminator metadata, # use it directly instead of wrapping it again annotation = self.field_info.annotation diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 3743ac0eff7..0d19928020c 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1050,7 +1050,7 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup # This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body() has_discriminator_with_body = False powertools_annotation: FieldInfo | None = None - + if len(powertools_annotations) == 2: field_obj = None body_obj = None From ead3ee8563476190eddb066cbb1e8c34e6d2fc6b Mon Sep 17 00:00:00 2001 From: dap0am Date: Wed, 3 Sep 2025 13:30:55 +0100 Subject: [PATCH 04/12] fix(event-handler): preserve FieldInfo subclass types in copy_field_info Fix regression where copy_field_info was losing custom FieldInfo subclass types (Body, Query, etc.) by using shallow copy instead of from_annotation. This resolves the failing test_validate_embed_body_param while maintaining the discriminator functionality. --- aws_lambda_powertools/event_handler/openapi/compat.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 74945748921..af5c2d5bc87 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -186,7 +186,13 @@ def model_rebuild(model: type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: - return type(field_info).from_annotation(annotation) + # Create a shallow copy of the field_info to preserve its type and all attributes + import copy + + new_field = copy.copy(field_info) + # Update only the annotation to the new one + new_field.annotation = annotation + return new_field def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]: From e2c8b4986cf295256b4fbb9933ddba80ca58767c Mon Sep 17 00:00:00 2001 From: dap0am Date: Thu, 4 Sep 2025 15:54:08 +0100 Subject: [PATCH 05/12] refactor(event-handler): reduce cognitive complexity and address SonarCloud issues - Refactor get_field_info_annotated_type function by extracting helper functions to reduce cognitive complexity from 29 to below 15 - Fix copy_field_info to preserve FieldInfo subclass types using shallow copy instead of from_annotation - Rename variable Action to action_type to follow Python naming conventions - Resolve failing test_validate_embed_body_param by maintaining Body parameter type recognition - Add helper functions: _has_discriminator, _handle_discriminator_with_body, _create_field_info, _set_field_default - Maintain full backward compatibility and discriminator functionality --- .../event_handler/openapi/compat.py | 4 +- .../event_handler/openapi/params.py | 119 +++++++++++------- .../test_openapi_validation_middleware.py | 4 +- 3 files changed, 77 insertions(+), 50 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index af5c2d5bc87..0121cd82cf7 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -187,9 +187,9 @@ def model_rebuild(model: type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: # Create a shallow copy of the field_info to preserve its type and all attributes - import copy + from copy import copy - new_field = copy.copy(field_info) + new_field = copy(field_info) # Update only the annotation to the new one new_field.annotation = annotation return new_field diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 0d19928020c..bbe522bff01 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1037,70 +1037,97 @@ def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, A return get_field_info_and_type_annotation(inner_type, value, False, True) +def _has_discriminator(field_info: FieldInfo) -> bool: + """Check if a FieldInfo has a discriminator.""" + return hasattr(field_info, "discriminator") and field_info.discriminator is not None + + +def _handle_discriminator_with_body( + annotations: list[FieldInfo], annotation: Any, +) -> tuple[FieldInfo | None, Any, bool]: + """ + Handle the special case of Field(discriminator) + Body() combination. + + Returns: + tuple of (powertools_annotation, type_annotation, has_discriminator_with_body) + """ + field_obj = None + body_obj = None + + for ann in annotations: + if isinstance(ann, Body): + body_obj = ann + elif _has_discriminator(ann): + field_obj = ann + + if field_obj and body_obj: + # Use Body as the primary annotation, preserve full annotation for validation + return body_obj, annotation, True + + raise AssertionError("Only one FieldInfo can be used per parameter") + + +def _create_field_info( + powertools_annotation: FieldInfo, + type_annotation: Any, + has_discriminator_with_body: bool, +) -> FieldInfo: + """Create or copy FieldInfo based on the annotation type.""" + if has_discriminator_with_body: + # For discriminator + Body case, create a new Body instance directly + field_info = Body() + field_info.annotation = type_annotation + else: + # Copy field_info because we mutate field_info.default later + field_info = copy_field_info( + field_info=powertools_annotation, + annotation=type_annotation, + ) + return field_info + + +def _set_field_default(field_info: FieldInfo, value: Any, is_path_param: bool) -> None: + """Set the default value for a field.""" + if field_info.default not in [Undefined, Required]: + raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") + + if value is not inspect.Signature.empty: + if is_path_param: + raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value") + field_info.default = value + else: + field_info.default = Required + + def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]: """ Get the FieldInfo and type annotation from an Annotated type. """ - field_info: FieldInfo | None = None annotated_args = get_args(annotation) type_annotation = annotated_args[0] powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] - # Special case: handle Field(discriminator) + Body() combination - # This happens when using Annotated[Union[A, B], Field(discriminator='...')] with Body() - has_discriminator_with_body = False + # Determine which annotation to use powertools_annotation: FieldInfo | None = None + has_discriminator_with_body = False if len(powertools_annotations) == 2: - field_obj = None - body_obj = None - for ann in powertools_annotations: - if isinstance(ann, Body): - body_obj = ann - elif isinstance(ann, FieldInfo) and hasattr(ann, "discriminator") and ann.discriminator is not None: - field_obj = ann - - if field_obj and body_obj: - # Use Body as the primary annotation - powertools_annotation = body_obj - # Preserve the full annotation including discriminator for proper validation - # This ensures the discriminator is available when creating the TypeAdapter - type_annotation = annotation - has_discriminator_with_body = True - else: - raise AssertionError("Only one FieldInfo can be used per parameter") + powertools_annotation, type_annotation, has_discriminator_with_body = _handle_discriminator_with_body( + powertools_annotations, annotation, + ) elif len(powertools_annotations) > 1: raise AssertionError("Only one FieldInfo can be used per parameter") else: powertools_annotation = next(iter(powertools_annotations), None) + # Process the annotation if it exists + field_info: FieldInfo | None = None if isinstance(powertools_annotation, FieldInfo): - if has_discriminator_with_body: - # For discriminator + Body case, create a new Body instance directly - # This avoids issues with copy_field_info trying to process the Field - field_info = Body() - field_info.annotation = type_annotation - else: - # Copy `field_info` because we mutate `field_info.default` later - # Use the possibly modified type_annotation for copy_field_info - field_info = copy_field_info( - field_info=powertools_annotation, - annotation=type_annotation, - ) - if field_info.default not in [Undefined, Required]: - raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") - - if value is not inspect.Signature.empty: - if is_path_param: - raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value") - field_info.default = value - else: - field_info.default = Required + field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_body) + _set_field_default(field_info, value, is_path_param) - # Preserve the full annotated type if it contains discriminator metadata - # This is crucial for tagged unions to work properly - if hasattr(powertools_annotation, "discriminator") and powertools_annotation.discriminator is not None: - # Store the full annotated type for discriminated unions + # Preserve full annotated type for discriminated unions + if _has_discriminator(powertools_annotation): type_annotation = annotation return field_info, type_annotation diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index dc55cd55772..19ab7568fab 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1997,10 +1997,10 @@ class BarAction(BaseModel): action: Literal["bar"] bar_data: int - Action = Annotated[FooAction | BarAction, Field(discriminator="action")] + action_type = Annotated[FooAction | BarAction, Field(discriminator="action")] @app.post("/actions") - def create_action(action: Annotated[Action, Body()]): + def create_action(action: Annotated[action_type, Body()]): return {"received_action": action.action, "data": action.model_dump()} gw_event["path"] = "/actions" From 95a5eba394f76e9b134b7f7d2ffda0765d6a904d Mon Sep 17 00:00:00 2001 From: dap0am Date: Fri, 5 Sep 2025 11:38:26 +0100 Subject: [PATCH 06/12] style: fix formatting to pass CI format check Apply ruff formatting to params.py to resolve failing format check in CI --- aws_lambda_powertools/event_handler/openapi/params.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index bbe522bff01..7c81c767fc2 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1043,7 +1043,8 @@ def _has_discriminator(field_info: FieldInfo) -> bool: def _handle_discriminator_with_body( - annotations: list[FieldInfo], annotation: Any, + annotations: list[FieldInfo], + annotation: Any, ) -> tuple[FieldInfo | None, Any, bool]: """ Handle the special case of Field(discriminator) + Body() combination. @@ -1113,7 +1114,8 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup if len(powertools_annotations) == 2: powertools_annotation, type_annotation, has_discriminator_with_body = _handle_discriminator_with_body( - powertools_annotations, annotation, + powertools_annotations, + annotation, ) elif len(powertools_annotations) > 1: raise AssertionError("Only one FieldInfo can be used per parameter") From d839919332a2eba45976fae1699c9d7e3c449acd Mon Sep 17 00:00:00 2001 From: dap0am Date: Mon, 8 Sep 2025 11:09:41 +0100 Subject: [PATCH 07/12] fix: resolve mypy type error in _create_field_info function Add explicit type annotation for field_info variable to fix mypy error about incompatible types between FieldInfo and Body. This ensures type checking passes across all Python versions (3.9-3.13). --- aws_lambda_powertools/event_handler/openapi/params.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 7c81c767fc2..9f205601066 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1074,6 +1074,7 @@ def _create_field_info( has_discriminator_with_body: bool, ) -> FieldInfo: """Create or copy FieldInfo based on the annotation type.""" + field_info: FieldInfo if has_discriminator_with_body: # For discriminator + Body case, create a new Body instance directly field_info = Body() From 5762c94808ff0aeb5c5235951af2b8e5f148e22e Mon Sep 17 00:00:00 2001 From: dap0am Date: Mon, 8 Sep 2025 17:08:05 +0100 Subject: [PATCH 08/12] fix: use Union syntax for Python 3.9 compatibility --- .../_pydantic/test_openapi_validation_middleware.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 19ab7568fab..954a0514dc8 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import PurePath -from typing import List, Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple, Union import pytest from pydantic import BaseModel, Field @@ -1997,7 +1997,7 @@ class BarAction(BaseModel): action: Literal["bar"] bar_data: int - action_type = Annotated[FooAction | BarAction, Field(discriminator="action")] + action_type = Annotated[Union[FooAction, BarAction], Field(discriminator="action")] @app.post("/actions") def create_action(action: Annotated[action_type, Body()]): From 793a09790bcfab0606416d5b8d325aec2b5a1324 Mon Sep 17 00:00:00 2001 From: dap0am Date: Tue, 9 Sep 2025 13:03:12 +0100 Subject: [PATCH 09/12] feat(event-handler): add documentation and example for Field discriminator support --- .../event_handler/openapi/compat.py | 3 +- docs/core/event_handler/api_gateway.md | 17 +++++++ .../src/discriminated_unions.py | 47 +++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 examples/event_handler_rest/src/discriminated_unions.py diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 0121cd82cf7..8223f2ed1af 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -3,6 +3,7 @@ from collections import deque from collections.abc import Mapping, Sequence +from copy import copy # MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different # versions of a module, so we need to ignore errors here. @@ -187,8 +188,6 @@ def model_rebuild(model: type[BaseModel]) -> None: def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: # Create a shallow copy of the field_info to preserve its type and all attributes - from copy import copy - new_field = copy(field_info) # Update only the annotation to the new one new_field.annotation = annotation diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index af0600a9f22..80a54302480 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -568,6 +568,23 @@ You can use the `Form` type to tell the Event Handler that a parameter expects f --8<-- "examples/event_handler_rest/src/working_with_form_data.py" ``` +#### Discriminated unions + +!!! info "You must set `enable_validation=True` to use discriminated unions via type annotation." + +You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body. + +In the following example, we define two action types (`FooAction` and `BarAction`) that share a common discriminator field `action`. The Event Handler will automatically parse the request body and instantiate the correct model based on the `action` field value: + +```python hl_lines="3 4 8 31 36" title="discriminated_unions.py" +--8<-- "examples/event_handler_rest/src/discriminated_unions.py" +``` + +1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate +2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union + +When you send a request with `{"action": "foo", "foo_data": "example"}`, the Event Handler will automatically create a `FooAction` instance. Similarly, `{"action": "bar", "bar_data": 42}` will create a `BarAction` instance. + #### Supported types for response serialization With data validation enabled, we natively support serializing the following data types to JSON: diff --git a/examples/event_handler_rest/src/discriminated_unions.py b/examples/event_handler_rest/src/discriminated_unions.py new file mode 100644 index 00000000000..5e25eae243b --- /dev/null +++ b/examples/event_handler_rest/src/discriminated_unions.py @@ -0,0 +1,47 @@ +from typing import Literal, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Body +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = APIGatewayRestResolver(enable_validation=True) + + +class FooAction(BaseModel): + """Action type for foo operations.""" + + action: Literal["foo"] = "foo" + foo_data: str + + +class BarAction(BaseModel): + """Action type for bar operations.""" + + action: Literal["bar"] = "bar" + bar_data: int + + +ActionType = Annotated[Union[FooAction, BarAction], Field(discriminator="action")] # (1)! + + +@app.post("/actions") +@tracer.capture_method +def handle_action(action: Annotated[ActionType, Body(description="Action to perform")]): # (2)! + """Handle different action types using discriminated unions.""" + if isinstance(action, FooAction): + return {"message": f"Handling foo action with data: {action.foo_data}"} + elif isinstance(action, BarAction): + return {"message": f"Handling bar action with data: {action.bar_data}"} + + +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) \ No newline at end of file From 4400181bc3ca0c86991761de8647a72648eb4afd Mon Sep 17 00:00:00 2001 From: dap0am Date: Tue, 9 Sep 2025 13:19:21 +0100 Subject: [PATCH 10/12] style: run make format to fix CI formatting issues --- examples/event_handler_rest/src/discriminated_unions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/event_handler_rest/src/discriminated_unions.py b/examples/event_handler_rest/src/discriminated_unions.py index 5e25eae243b..dec2104b0b0 100644 --- a/examples/event_handler_rest/src/discriminated_unions.py +++ b/examples/event_handler_rest/src/discriminated_unions.py @@ -44,4 +44,4 @@ def handle_action(action: Annotated[ActionType, Body(description="Action to perf @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) @tracer.capture_lambda_handler def lambda_handler(event: dict, context: LambdaContext) -> dict: - return app.resolve(event, context) \ No newline at end of file + return app.resolve(event, context) From 587d2fad9fa86d5757e44ed368d6f54294f97461 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 11 Sep 2025 15:10:22 +0100 Subject: [PATCH 11/12] small changes --- .../event_handler/openapi/compat.py | 3 -- .../event_handler/openapi/params.py | 8 ++--- docs/core/event_handler/api_gateway.md | 28 ++++++--------- .../handlers/data_validation_with_fields.py | 34 +++++++++++++++++++ tests/e2e/event_handler/infrastructure.py | 4 +++ tests/e2e/event_handler/test_openapi.py | 17 ++++++++++ .../test_openapi_validation_middleware.py | 32 ----------------- 7 files changed, 70 insertions(+), 56 deletions(-) create mode 100644 tests/e2e/event_handler/handlers/data_validation_with_fields.py diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 8223f2ed1af..6b2c691442f 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -4,9 +4,6 @@ from collections import deque from collections.abc import Mapping, Sequence from copy import copy - -# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different -# versions of a module, so we need to ignore errors here. from dataclasses import dataclass, is_dataclass from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 9f205601066..454dce685a7 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1042,7 +1042,7 @@ def _has_discriminator(field_info: FieldInfo) -> bool: return hasattr(field_info, "discriminator") and field_info.discriminator is not None -def _handle_discriminator_with_body( +def _handle_discriminator_with_param( annotations: list[FieldInfo], annotation: Any, ) -> tuple[FieldInfo | None, Any, bool]: @@ -1111,10 +1111,10 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup # Determine which annotation to use powertools_annotation: FieldInfo | None = None - has_discriminator_with_body = False + has_discriminator_with_param = False if len(powertools_annotations) == 2: - powertools_annotation, type_annotation, has_discriminator_with_body = _handle_discriminator_with_body( + powertools_annotation, type_annotation, has_discriminator_with_param = _handle_discriminator_with_param( powertools_annotations, annotation, ) @@ -1126,7 +1126,7 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup # Process the annotation if it exists field_info: FieldInfo | None = None if isinstance(powertools_annotation, FieldInfo): - field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_body) + field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param) _set_field_default(field_info, value, is_path_param) # Preserve full annotated type for discriminated unions diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 80a54302480..77858766f16 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -428,6 +428,17 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou --8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json" ``` +##### Discriminated unions + +You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body. + +```python hl_lines="3 4 8 31 36" title="discriminated_unions.py" +--8<-- "examples/event_handler_rest/src/discriminated_unions.py" +``` + +1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate +2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union + #### Validating responses You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`. @@ -568,23 +579,6 @@ You can use the `Form` type to tell the Event Handler that a parameter expects f --8<-- "examples/event_handler_rest/src/working_with_form_data.py" ``` -#### Discriminated unions - -!!! info "You must set `enable_validation=True` to use discriminated unions via type annotation." - -You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body. - -In the following example, we define two action types (`FooAction` and `BarAction`) that share a common discriminator field `action`. The Event Handler will automatically parse the request body and instantiate the correct model based on the `action` field value: - -```python hl_lines="3 4 8 31 36" title="discriminated_unions.py" ---8<-- "examples/event_handler_rest/src/discriminated_unions.py" -``` - -1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate -2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union - -When you send a request with `{"action": "foo", "foo_data": "example"}`, the Event Handler will automatically create a `FooAction` instance. Similarly, `{"action": "bar", "bar_data": 42}` will create a `BarAction` instance. - #### Supported types for response serialization With data validation enabled, we natively support serializing the following data types to JSON: diff --git a/tests/e2e/event_handler/handlers/data_validation_with_fields.py b/tests/e2e/event_handler/handlers/data_validation_with_fields.py new file mode 100644 index 00000000000..64ddacdef64 --- /dev/null +++ b/tests/e2e/event_handler/handlers/data_validation_with_fields.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Body + +app = APIGatewayRestResolver(enable_validation=True) +app.enable_swagger() + + +class FooAction(BaseModel): + action: Literal["foo"] + foo_data: str + + +class BarAction(BaseModel): + action: Literal["bar"] + bar_data: int + + +Action = Annotated[FooAction | BarAction, Field(discriminator="action")] + + +@app.post("/data_validation_with_fields") +def create_action(action: Annotated[Action, Body(discriminator="action")]): + return {"message": "Powertools e2e API"} + + +def lambda_handler(event, context): + print(event) + return app.resolve(event, context) diff --git a/tests/e2e/event_handler/infrastructure.py b/tests/e2e/event_handler/infrastructure.py index 46f7cfe2473..6e92f0d4e21 100644 --- a/tests/e2e/event_handler/infrastructure.py +++ b/tests/e2e/event_handler/infrastructure.py @@ -24,6 +24,7 @@ def create_resources(self): functions["OpenapiHandler"], functions["OpenapiHandlerWithPep563"], functions["DataValidationAndMiddleware"], + functions["DataValidationWithFields"], ], ) self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"]) @@ -105,6 +106,9 @@ def _create_api_gateway_rest(self, function: list[Function]): openapi_schema = apigw.root.add_resource("data_validation_middleware") openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[3], proxy=True)) + openapi_schema = apigw.root.add_resource("data_validation_with_fields") + openapi_schema.add_method("POST", apigwv1.LambdaIntegration(function[4], proxy=True)) + CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url) def _create_lambda_function_url(self, function: Function): diff --git a/tests/e2e/event_handler/test_openapi.py b/tests/e2e/event_handler/test_openapi.py index b5255e44661..56bf9dfc9ed 100644 --- a/tests/e2e/event_handler/test_openapi.py +++ b/tests/e2e/event_handler/test_openapi.py @@ -59,3 +59,20 @@ def test_get_openapi_validation_and_middleware(apigw_rest_endpoint): ) assert response.status_code == 202 + + +def test_openapi_with_fields_discriminator(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}data_validation_with_fields" + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + json={"action": "foo", "foo_data": "foo data working"}, + ), + ) + + assert "Powertools e2e API" in response.text + assert response.status_code == 200 diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 954a0514dc8..de1add78d55 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -2030,35 +2030,3 @@ def create_action(action: Annotated[action_type, Body()]): result = app(gw_event, {}) assert result["statusCode"] == 422 - - -def test_field_other_features_still_work(gw_event): - """Test that other Field features still work after discriminator fix""" - app = APIGatewayRestResolver(enable_validation=True) - - class UserInput(BaseModel): - name: Annotated[str, Field(min_length=2, max_length=50, description="User name")] - age: Annotated[int, Field(ge=18, le=120, description="User age")] - email: Annotated[str, Field(pattern=r".+@.+\..+", description="User email")] - - @app.post("/users") - def create_user(user: UserInput): - return {"created": user.model_dump()} - - gw_event["path"] = "/users" - gw_event["httpMethod"] = "POST" - gw_event["headers"]["content-type"] = "application/json" - gw_event["body"] = '{"name": "John", "age": 25, "email": "john@example.com"}' - - result = app(gw_event, {}) - assert result["statusCode"] == 200 - - response_body = json.loads(result["body"]) - assert response_body["created"]["name"] == "John" - assert response_body["created"]["age"] == 25 - assert response_body["created"]["email"] == "john@example.com" - - gw_event["body"] = '{"name": "John", "age": 16, "email": "john@example.com"}' - - result = app(gw_event, {}) - assert result["statusCode"] == 422 From 851992fef22640f6e8c8086e013da3f4bf08087b Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 11 Sep 2025 16:10:51 +0100 Subject: [PATCH 12/12] small changes --- aws_lambda_powertools/event_handler/openapi/params.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 454dce685a7..1919fb9fe77 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -1125,13 +1125,13 @@ def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tup # Process the annotation if it exists field_info: FieldInfo | None = None - if isinstance(powertools_annotation, FieldInfo): + if isinstance(powertools_annotation, FieldInfo): # pragma: no cover field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param) _set_field_default(field_info, value, is_path_param) # Preserve full annotated type for discriminated unions - if _has_discriminator(powertools_annotation): - type_annotation = annotation + if _has_discriminator(powertools_annotation): # pragma: no cover + type_annotation = annotation # pragma: no cover return field_info, type_annotation