Skip to content

Commit 6484266

Browse files
committed
fix: respect FastAPI dependency_overrides in get_injected_obj
This commit fixes a bug where `get_injected_obj` did not correctly respect FastAPI's `app.dependency_overrides` for top-level dependencies. The issue was that `get_injected_obj` was not explicitly integrating with FastAPI's dependency injection system when called outside of a route handler. This meant that when `app.dependency_overrides` was used to mock a dependency for testing, `get_injected_obj` would bypass the override and attempt to resolve the original dependency. To fix this, `get_injected_obj` now wraps the provided function in a pass-through dependency using `Annotated` and `Depends`. This makes the dependency explicitly visible to FastAPI, allowing `dependency_overrides` to function as expected. This change ensures that `get_injected_obj` can be reliably used in testing scenarios with mocked dependencies.
1 parent 858f714 commit 6484266

File tree

7 files changed

+400
-135
lines changed

7 files changed

+400
-135
lines changed

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def mypy(session: Session) -> None:
134134
def tests(session: Session) -> None:
135135
"""Run the test suite."""
136136
session.install(".")
137-
session.install("coverage[toml]", "pytest", "pytest-asyncio")
137+
session.install("coverage[toml]", "pytest", "pytest-asyncio", "httpx")
138138
try:
139139
session.run("coverage", "run", "--parallel-mode", "-m", "pytest", *session.posargs)
140140
finally:

poetry.lock

Lines changed: 56 additions & 87 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ ipython = "^8.30.0"
3939
ipdb = "^0.13.13"
4040
furo = "^2024.8.6"
4141
coverage = { extras = ["toml"], version = "^7.6.9" }
42+
httpx = "^0.28.1"
4243

4344
[tool.ruff]
4445
lint.ignore = [

src/fastapi_injectable/decorator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ def _override_func_dependency_signature(func: Callable[P, T] | Callable[P, Await
5858
"Injected_" + getattr(base_class, "__name__", "Injected"),
5959
(base_class,),
6060
{},
61-
lambda ns: ns.update({"__init__": lambda self, *args, **kwargs: None}), # noqa: ARG005
61+
lambda ns: ns.update(
62+
{
63+
"__init__": lambda self, *args, **kwargs: None, # noqa: ARG005
64+
**{
65+
method: lambda *args, **kwargs: None # noqa: ARG005
66+
for method in getattr(base_class, "__abstractmethods__", []) # noqa: B023
67+
},
68+
}
69+
),
6270
)
6371
parameter = inspect.Parameter.replace(param, default=dynamic_default())
6472
new_parameters.append(parameter)

src/fastapi_injectable/util.py

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,85 @@
11
import atexit
22
import inspect
33
import signal
4-
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Sequence
5-
from typing import Any, ParamSpec, TypeVar, cast, overload
4+
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator, Sequence
5+
from typing import Annotated, Any, ParamSpec, TypeVar, cast, get_type_hints, overload
6+
7+
from fastapi import Depends
68

79
from .async_exit_stack import async_exit_stack_manager
810
from .cache import dependency_cache
911
from .concurrency import run_coroutine_sync
1012
from .decorator import injectable
1113

1214
T = TypeVar("T")
15+
T2 = TypeVar("T2")
16+
T3 = TypeVar("T3")
1317
P = ParamSpec("P")
1418

19+
PROVIDER_TO_WRAPPER_FUNC_MAP: dict[Callable[..., Any], list[Callable[[Any], Any]]] = {}
20+
21+
22+
def _create_depends_function(
23+
provider: Callable[..., Any],
24+
) -> Callable[..., Any]:
25+
"""Build a pass-through dependency for FastAPI.
26+
27+
Related issue: https://github.com/JasperSui/fastapi-injectable/issues/153
28+
29+
The returned callable has a single parameter whose annotation is:
30+
Annotated[<provider_return_type>, Depends(provider)]
31+
and it simply returns that parameter.
32+
33+
Type checkers see this as Callable[[T], T] with T inferred from the provider's return type.
34+
35+
Raises:
36+
TypeError if the provider's return type cannot be determined.
37+
"""
38+
# Runtime: resolve the *concrete* return type for FastAPI's inspection
39+
try:
40+
hints = get_type_hints(provider, include_extras=True)
41+
except Exception: # pragma: no cover # noqa: BLE001
42+
hints = {}
43+
44+
rt = hints.get("return", inspect.Signature.empty)
45+
46+
if rt in (inspect.Signature.empty, Any, None): # pragma: no cover
47+
msg = (
48+
f"Cannot infer return type for provider {getattr(provider, '__name__', repr(provider))}. "
49+
"Please add an explicit return annotation."
50+
)
51+
raise TypeError(msg)
52+
53+
def inner(dep: T2) -> T2:
54+
return dep
55+
56+
# Provide the annotations FastAPI inspects at runtime
57+
inner.__annotations__ = {
58+
"dep": Annotated[rt, Depends(provider)],
59+
"return": rt,
60+
}
61+
62+
# Nice signature for docs/inspection
63+
inner.__signature__ = inspect.Signature( # type: ignore[attr-defined]
64+
parameters=[
65+
inspect.Parameter(
66+
"dep",
67+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
68+
annotation=Annotated[rt, Depends(provider)],
69+
)
70+
],
71+
return_annotation=rt,
72+
)
73+
74+
inner.__name__ = f"{getattr(provider, '__name__', 'provider')}_extractor"
75+
76+
# Store the wrapper function for cleanup the provider later
77+
if provider not in PROVIDER_TO_WRAPPER_FUNC_MAP:
78+
PROVIDER_TO_WRAPPER_FUNC_MAP[provider] = []
79+
PROVIDER_TO_WRAPPER_FUNC_MAP[provider].append(inner)
80+
81+
return inner
82+
1583

1684
@overload
1785
def get_injected_obj(
@@ -111,30 +179,22 @@ def get_db() -> Generator[Database, None, None]:
111179
- Cleanup code in generators will be executed when calling cleanup functions
112180
- Uses FastAPI's dependency injection system under the hood
113181
"""
114-
injectable_func = injectable(func, use_cache=use_cache)
115-
116182
if args is None:
117183
args = []
118184
if kwargs is None:
119185
kwargs = {}
120186

121-
if inspect.isasyncgenfunction(func):
122-
# Handle async generator
123-
async_gen = cast(AsyncGenerator[T, Any], injectable_func(*args, **kwargs))
124-
return run_coroutine_sync(anext(async_gen))
125-
126-
if inspect.isgeneratorfunction(func):
127-
# Handle sync generator
128-
gen = cast(Generator[T, Any, Any], injectable_func(*args, **kwargs))
129-
return next(gen)
130-
131-
if inspect.iscoroutinefunction(func):
132-
# Handle coroutine
133-
coro = cast(Coroutine[Any, Any, T], injectable_func(*args, **kwargs))
134-
return run_coroutine_sync(coro)
187+
wrapped_func = _create_depends_function(func)
188+
injectable_func = injectable(wrapped_func, use_cache=use_cache)
189+
result = injectable_func(*args, **kwargs) # type: ignore[no-untyped-call]
135190

136-
# Handle regular function
137-
return cast(T, injectable_func(*args, **kwargs))
191+
if inspect.isasyncgen(result):
192+
return cast("T", run_coroutine_sync(anext(result)))
193+
if inspect.isgenerator(result):
194+
return cast("T", next(result))
195+
if inspect.isawaitable(result):
196+
return cast("T", run_coroutine_sync(result)) # type: ignore[arg-type]
197+
return cast("T", result)
138198

139199

140200
async def cleanup_exit_stack_of_func(func: Callable[..., Any], *, raise_exception: bool = False) -> None:
@@ -152,7 +212,8 @@ async def cleanup_exit_stack_of_func(func: Callable[..., Any], *, raise_exceptio
152212
Raises:
153213
DependencyCleanupError: When cleanup fails and raise_exception is True
154214
"""
155-
await async_exit_stack_manager.cleanup_stack(func, raise_exception=raise_exception)
215+
for wrapper in PROVIDER_TO_WRAPPER_FUNC_MAP.get(func, [func]):
216+
await async_exit_stack_manager.cleanup_stack(wrapper, raise_exception=raise_exception)
156217

157218

158219
async def cleanup_all_exit_stacks(*, raise_exception: bool = False) -> None:

test/test_integration.py

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
from collections.abc import AsyncGenerator, Generator
22
from typing import Annotated, Any
3+
from unittest.mock import Mock
34

45
import pytest
5-
from fastapi import Depends
6+
from fastapi import Depends, FastAPI
7+
from fastapi.testclient import TestClient
68

7-
from src.fastapi_injectable.concurrency import run_coroutine_sync
9+
from src.fastapi_injectable import register_app
10+
from src.fastapi_injectable.concurrency import loop_manager, run_coroutine_sync
811
from src.fastapi_injectable.decorator import injectable
9-
from src.fastapi_injectable.util import cleanup_all_exit_stacks, cleanup_exit_stack_of_func, get_injected_obj
12+
from src.fastapi_injectable.util import (
13+
cleanup_all_exit_stacks,
14+
cleanup_exit_stack_of_func,
15+
get_injected_obj,
16+
)
1017

1118

1219
@pytest.fixture
@@ -604,3 +611,154 @@ def get_country(
604611

605612
country: Country = get_country(basic_str="basic_str", basic_int=1, basic_bool=True, basic_dict={"key": "value"}) # type: ignore # noqa: PGH003
606613
assert country is not None
614+
615+
616+
def test_get_injected_obj_with_dependency_override_sync(clean_exit_stack_manager: None) -> None:
617+
"""Tests that get_injected_obj respects dependency_overrides for sync dependencies."""
618+
619+
def sync_dependency_override() -> int:
620+
return 1
621+
622+
def use_sync_dependency_override() -> int:
623+
return get_injected_obj(sync_dependency_override)
624+
625+
loop_manager.set_loop_strategy(
626+
"background_thread"
627+
) # To avoid affecting the FastAPI app and httpx client event loop
628+
app = FastAPI()
629+
630+
@app.get("/")
631+
def read_root() -> int:
632+
return use_sync_dependency_override()
633+
634+
mock_dependency = Mock(return_value=2)
635+
app.dependency_overrides[sync_dependency_override] = lambda: mock_dependency()
636+
run_coroutine_sync(register_app(app))
637+
638+
client = TestClient(app)
639+
response = client.get("/")
640+
assert response.status_code == 200
641+
assert response.json() == 2
642+
mock_dependency.assert_called_once()
643+
644+
645+
@pytest.mark.asyncio
646+
async def test_get_injected_obj_with_dependency_override_async(clean_exit_stack_manager: None) -> None:
647+
"""Tests that get_injected_obj respects dependency_overrides for async dependencies."""
648+
649+
async def async_dependency_override() -> int:
650+
return 1
651+
652+
def use_async_dependency_override() -> int:
653+
return get_injected_obj(async_dependency_override)
654+
655+
loop_manager.set_loop_strategy(
656+
"background_thread"
657+
) # To avoid affecting the FastAPI app and httpx client event loop
658+
app = FastAPI()
659+
660+
@app.get("/")
661+
async def read_root() -> int:
662+
return use_async_dependency_override()
663+
664+
mock_dependency = Mock(return_value=2)
665+
app.dependency_overrides[async_dependency_override] = lambda: mock_dependency()
666+
await register_app(app)
667+
668+
client = TestClient(app)
669+
response = client.get("/")
670+
assert response.status_code == 200
671+
assert response.json() == 2
672+
mock_dependency.assert_called_once()
673+
674+
675+
@pytest.mark.asyncio
676+
async def test_get_injected_obj_with_dependency_override_sync_generator(
677+
clean_exit_stack_manager: None,
678+
) -> None:
679+
"""Tests that get_injected_obj respects dependency_overrides for sync generators."""
680+
sync_cleanup_mock_override = Mock()
681+
682+
def sync_gen_dependency_override() -> Generator[int, None, None]:
683+
try:
684+
yield 1
685+
finally:
686+
sync_cleanup_mock_override()
687+
688+
def use_sync_gen_dependency_override() -> int:
689+
return get_injected_obj(sync_gen_dependency_override)
690+
691+
override_sync_cleanup_mock = Mock()
692+
693+
def override_sync_gen() -> Generator[int, None, None]:
694+
try:
695+
yield 2
696+
finally:
697+
override_sync_cleanup_mock()
698+
699+
loop_manager.set_loop_strategy(
700+
"background_thread"
701+
) # To avoid affecting the FastAPI app and httpx client event loop
702+
app = FastAPI()
703+
704+
@app.get("/")
705+
def read_root() -> int:
706+
return use_sync_gen_dependency_override()
707+
708+
app.dependency_overrides[sync_gen_dependency_override] = override_sync_gen
709+
await register_app(app)
710+
711+
with TestClient(app) as client:
712+
response = client.get("/")
713+
assert response.status_code == 200
714+
assert response.json() == 2
715+
716+
await cleanup_all_exit_stacks()
717+
sync_cleanup_mock_override.assert_not_called()
718+
override_sync_cleanup_mock.assert_called_once()
719+
720+
721+
@pytest.mark.asyncio
722+
async def test_get_injected_obj_with_dependency_override_async_generator(
723+
clean_exit_stack_manager: None,
724+
) -> None:
725+
"""Tests that get_injected_obj respects dependency_overrides for async generators."""
726+
async_cleanup_mock_override = Mock()
727+
728+
async def async_gen_dependency_override() -> AsyncGenerator[int, None]:
729+
try:
730+
yield 1
731+
finally:
732+
async_cleanup_mock_override()
733+
734+
def use_async_gen_dependency_override() -> int:
735+
return get_injected_obj(async_gen_dependency_override)
736+
737+
override_async_cleanup_mock = Mock()
738+
739+
async def override_async_gen() -> AsyncGenerator[int, None]:
740+
try:
741+
yield 2
742+
finally:
743+
override_async_cleanup_mock()
744+
745+
loop_manager.set_loop_strategy(
746+
"background_thread"
747+
) # To avoid affecting the FastAPI app and httpx client event loop
748+
app = FastAPI()
749+
750+
@app.get("/")
751+
async def read_root() -> int:
752+
return use_async_gen_dependency_override()
753+
754+
app.dependency_overrides[async_gen_dependency_override] = override_async_gen
755+
await register_app(app)
756+
757+
with TestClient(app) as client:
758+
response = client.get("/")
759+
assert response.status_code == 200
760+
assert response.json() == 2
761+
762+
await cleanup_all_exit_stacks()
763+
async_cleanup_mock_override.assert_not_called()
764+
override_async_cleanup_mock.assert_called_once()

0 commit comments

Comments
 (0)