Skip to content

Commit 858f714

Browse files
authored
fix: handle Optional and Union types in Depends in injectable (#157)
1 parent 8e05123 commit 858f714

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/fastapi_injectable/decorator.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import types
33
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator
44
from functools import wraps
5-
from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeVar, cast, get_origin, overload
5+
from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeVar, Union, cast, get_args, get_origin, overload
66

77
import fastapi
88
import fastapi.params
@@ -41,9 +41,22 @@ def _override_func_dependency_signature(func: Callable[P, T] | Callable[P, Await
4141
fastapi_default = metadata
4242
break
4343
if fastapi_default:
44+
actual_type = get_args(param.annotation)[0]
45+
origin = get_origin(actual_type)
46+
47+
base_for_class = actual_type
48+
if origin is Union or origin is types.UnionType:
49+
union_args = get_args(actual_type)
50+
base_for_class = next(
51+
(t for t in union_args if t is not type(None)),
52+
union_args[0] if union_args else object,
53+
)
54+
55+
base_class = get_origin(base_for_class) or base_for_class
56+
4457
dynamic_default = types.new_class(
45-
"Injected_" + param.annotation.__origin__.__name__,
46-
(param.annotation.__origin__,),
58+
"Injected_" + getattr(base_class, "__name__", "Injected"),
59+
(base_class,),
4760
{},
4861
lambda ns: ns.update({"__init__": lambda self, *args, **kwargs: None}), # noqa: ARG005
4962
)

test/test_injectable.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,17 @@ def get_country(capital: Annotated[Capital, Depends(get_capital)]) -> Country:
395395
country_2 = country
396396
assert country_1.capital is country_2.capital
397397
assert country_1.capital.mayor is country_2.capital.mayor
398+
399+
400+
def test_injectable_converts_depends_with_optional_uniontype_types() -> None:
401+
async def get_mayor() -> Mayor | None:
402+
return None
403+
404+
@injectable(use_cache=True)
405+
def get_capital(mayor: Annotated[Mayor | None, Depends(get_mayor)]) -> Capital | None:
406+
return Capital(mayor) if mayor else None
407+
408+
sig = signature(get_capital)
409+
param = next(iter(sig.parameters.values()))
410+
411+
assert type(param.default).__name__.startswith("Injected")

0 commit comments

Comments
 (0)