|
2 | 2 | import types |
3 | 3 | from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator |
4 | 4 | from functools import wraps |
5 | | -from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeVar, cast, get_origin, overload |
| 5 | +from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeVar, Union, cast, get_args, get_origin, overload |
6 | 6 |
|
7 | 7 | import fastapi |
8 | 8 | import fastapi.params |
@@ -41,9 +41,22 @@ def _override_func_dependency_signature(func: Callable[P, T] | Callable[P, Await |
41 | 41 | fastapi_default = metadata |
42 | 42 | break |
43 | 43 | if fastapi_default: |
| 44 | + actual_type = get_args(param.annotation)[0] |
| 45 | + origin = get_origin(actual_type) |
| 46 | + |
| 47 | + base_for_class = actual_type |
| 48 | + if origin is Union or origin is types.UnionType: |
| 49 | + union_args = get_args(actual_type) |
| 50 | + base_for_class = next( |
| 51 | + (t for t in union_args if t is not type(None)), |
| 52 | + union_args[0] if union_args else object, |
| 53 | + ) |
| 54 | + |
| 55 | + base_class = get_origin(base_for_class) or base_for_class |
| 56 | + |
44 | 57 | dynamic_default = types.new_class( |
45 | | - "Injected_" + param.annotation.__origin__.__name__, |
46 | | - (param.annotation.__origin__,), |
| 58 | + "Injected_" + getattr(base_class, "__name__", "Injected"), |
| 59 | + (base_class,), |
47 | 60 | {}, |
48 | 61 | lambda ns: ns.update({"__init__": lambda self, *args, **kwargs: None}), # noqa: ARG005 |
49 | 62 | ) |
|
0 commit comments