11import functools
22import inspect
33import logging
4+ import sys
45import os
5- from typing import Callable , Optional
6+ from typing import Any , Callable , Dict , Optional , Union , cast , overload
7+
8+ if sys .version_info >= (3 , 8 ):
9+ from typing import Protocol
10+ else :
11+ from typing_extensions import Protocol
612
713from ..shared import constants
814from ..shared .functions import resolve_truthy_env_var_choice
915from ..tracing import Tracer
16+ from ..utilities .typing import LambdaContext
1017from .exceptions import MiddlewareInvalidArgumentError
1118
1219logger = logging .getLogger (__name__ )
1320
21+ # context: Any to avoid forcing users to type it as context: LambdaContext
22+ _Handler = Callable [[Any , LambdaContext ], Any ]
23+ _RawHandlerDecorator = Callable [[_Handler ], _Handler ]
24+
25+
26+ class _FactoryDecorator (Protocol ):
27+ # it'd be better for this to be using ParamSpec (available from 3.10)
28+ def __call__ (
29+ self , handler : _Handler , event : Dict [str , Any ], context : LambdaContext , ** kwargs : Any
30+ ) -> _RawHandlerDecorator :
31+ ...
32+
33+
34+ class _HandlerDecorator (Protocol ):
35+ @overload
36+ def __call__ (self , decorator : _Handler ) -> _Handler :
37+ ...
38+
39+ @overload
40+ def __call__ (self , decorator : None = None , ** kwargs : Any ) -> _RawHandlerDecorator :
41+ ...
42+
43+ def __call__ (self , decorator : Optional [_Handler ] = None , ** kwargs : Any ) -> Union [_Handler , _RawHandlerDecorator ]:
44+ ...
45+
46+
47+ @overload
48+ def lambda_handler_decorator (decorator : _FactoryDecorator ) -> _HandlerDecorator :
49+ ...
50+
51+
52+ @overload
53+ def lambda_handler_decorator (
54+ decorator : None = None , trace_execution : Optional [bool ] = None
55+ ) -> Callable [[_FactoryDecorator ], _HandlerDecorator ]:
56+ ...
57+
1458
15- def lambda_handler_decorator (decorator : Optional [Callable ] = None , trace_execution : Optional [bool ] = None ):
59+ def lambda_handler_decorator (
60+ decorator : Optional [_FactoryDecorator ] = None , trace_execution : Optional [bool ] = None
61+ ) -> Union [_HandlerDecorator , Callable [[_FactoryDecorator ], _HandlerDecorator ]]:
1662 """Decorator factory for decorating Lambda handlers.
1763
1864 You can use lambda_handler_decorator to create your own middlewares,
@@ -103,19 +149,25 @@ def lambda_handler(event, context):
103149 """
104150
105151 if decorator is None :
106- return functools .partial (lambda_handler_decorator , trace_execution = trace_execution )
152+ return cast (
153+ Callable [[_FactoryDecorator ], _HandlerDecorator ],
154+ functools .partial (lambda_handler_decorator , trace_execution = trace_execution ),
155+ )
107156
108157 trace_execution = resolve_truthy_env_var_choice (
109158 env = os .getenv (constants .MIDDLEWARE_FACTORY_TRACE_ENV , "false" ), choice = trace_execution
110159 )
111160
112161 @functools .wraps (decorator )
113- def final_decorator (func : Optional [Callable ] = None , ** kwargs ):
162+ def final_decorator (
163+ func : Optional [_RawHandlerDecorator ] = None , ** kwargs : Any
164+ ) -> Union [_Handler , _RawHandlerDecorator ]:
114165 # If called with kwargs return new func with kwargs
115166 if func is None :
116167 return functools .partial (final_decorator , ** kwargs )
117168
118169 if not inspect .isfunction (func ):
170+ assert decorator is not None
119171 # @custom_middleware(True) vs @custom_middleware(log_event=True)
120172 raise MiddlewareInvalidArgumentError (
121173 f"Only keyword arguments is supported for middlewares: { decorator .__qualname__ } received { func } " # type: ignore # noqa: E501
@@ -138,4 +190,4 @@ def wrapper(event, context):
138190
139191 return wrapper
140192
141- return final_decorator
193+ return cast ( _HandlerDecorator , final_decorator )
0 commit comments