|
1 | 1 | from collections.abc import AsyncGenerator, Generator |
2 | 2 | from typing import Annotated, Any |
| 3 | +from unittest.mock import Mock |
3 | 4 |
|
4 | 5 | import pytest |
5 | | -from fastapi import Depends |
| 6 | +from fastapi import Depends, FastAPI |
| 7 | +from fastapi.testclient import TestClient |
6 | 8 |
|
7 | | -from src.fastapi_injectable.concurrency import run_coroutine_sync |
| 9 | +from src.fastapi_injectable import register_app |
| 10 | +from src.fastapi_injectable.concurrency import loop_manager, run_coroutine_sync |
8 | 11 | from src.fastapi_injectable.decorator import injectable |
9 | | -from src.fastapi_injectable.util import cleanup_all_exit_stacks, cleanup_exit_stack_of_func, get_injected_obj |
| 12 | +from src.fastapi_injectable.util import ( |
| 13 | + cleanup_all_exit_stacks, |
| 14 | + cleanup_exit_stack_of_func, |
| 15 | + get_injected_obj, |
| 16 | +) |
10 | 17 |
|
11 | 18 |
|
12 | 19 | @pytest.fixture |
@@ -604,3 +611,154 @@ def get_country( |
604 | 611 |
|
605 | 612 | country: Country = get_country(basic_str="basic_str", basic_int=1, basic_bool=True, basic_dict={"key": "value"}) # type: ignore # noqa: PGH003 |
606 | 613 | assert country is not None |
| 614 | + |
| 615 | + |
| 616 | +def test_get_injected_obj_with_dependency_override_sync(clean_exit_stack_manager: None) -> None: |
| 617 | + """Tests that get_injected_obj respects dependency_overrides for sync dependencies.""" |
| 618 | + |
| 619 | + def sync_dependency_override() -> int: |
| 620 | + return 1 |
| 621 | + |
| 622 | + def use_sync_dependency_override() -> int: |
| 623 | + return get_injected_obj(sync_dependency_override) |
| 624 | + |
| 625 | + loop_manager.set_loop_strategy( |
| 626 | + "background_thread" |
| 627 | + ) # To avoid affecting the FastAPI app and httpx client event loop |
| 628 | + app = FastAPI() |
| 629 | + |
| 630 | + @app.get("/") |
| 631 | + def read_root() -> int: |
| 632 | + return use_sync_dependency_override() |
| 633 | + |
| 634 | + mock_dependency = Mock(return_value=2) |
| 635 | + app.dependency_overrides[sync_dependency_override] = lambda: mock_dependency() |
| 636 | + run_coroutine_sync(register_app(app)) |
| 637 | + |
| 638 | + client = TestClient(app) |
| 639 | + response = client.get("/") |
| 640 | + assert response.status_code == 200 |
| 641 | + assert response.json() == 2 |
| 642 | + mock_dependency.assert_called_once() |
| 643 | + |
| 644 | + |
| 645 | +@pytest.mark.asyncio |
| 646 | +async def test_get_injected_obj_with_dependency_override_async(clean_exit_stack_manager: None) -> None: |
| 647 | + """Tests that get_injected_obj respects dependency_overrides for async dependencies.""" |
| 648 | + |
| 649 | + async def async_dependency_override() -> int: |
| 650 | + return 1 |
| 651 | + |
| 652 | + def use_async_dependency_override() -> int: |
| 653 | + return get_injected_obj(async_dependency_override) |
| 654 | + |
| 655 | + loop_manager.set_loop_strategy( |
| 656 | + "background_thread" |
| 657 | + ) # To avoid affecting the FastAPI app and httpx client event loop |
| 658 | + app = FastAPI() |
| 659 | + |
| 660 | + @app.get("/") |
| 661 | + async def read_root() -> int: |
| 662 | + return use_async_dependency_override() |
| 663 | + |
| 664 | + mock_dependency = Mock(return_value=2) |
| 665 | + app.dependency_overrides[async_dependency_override] = lambda: mock_dependency() |
| 666 | + await register_app(app) |
| 667 | + |
| 668 | + client = TestClient(app) |
| 669 | + response = client.get("/") |
| 670 | + assert response.status_code == 200 |
| 671 | + assert response.json() == 2 |
| 672 | + mock_dependency.assert_called_once() |
| 673 | + |
| 674 | + |
| 675 | +@pytest.mark.asyncio |
| 676 | +async def test_get_injected_obj_with_dependency_override_sync_generator( |
| 677 | + clean_exit_stack_manager: None, |
| 678 | +) -> None: |
| 679 | + """Tests that get_injected_obj respects dependency_overrides for sync generators.""" |
| 680 | + sync_cleanup_mock_override = Mock() |
| 681 | + |
| 682 | + def sync_gen_dependency_override() -> Generator[int, None, None]: |
| 683 | + try: |
| 684 | + yield 1 |
| 685 | + finally: |
| 686 | + sync_cleanup_mock_override() |
| 687 | + |
| 688 | + def use_sync_gen_dependency_override() -> int: |
| 689 | + return get_injected_obj(sync_gen_dependency_override) |
| 690 | + |
| 691 | + override_sync_cleanup_mock = Mock() |
| 692 | + |
| 693 | + def override_sync_gen() -> Generator[int, None, None]: |
| 694 | + try: |
| 695 | + yield 2 |
| 696 | + finally: |
| 697 | + override_sync_cleanup_mock() |
| 698 | + |
| 699 | + loop_manager.set_loop_strategy( |
| 700 | + "background_thread" |
| 701 | + ) # To avoid affecting the FastAPI app and httpx client event loop |
| 702 | + app = FastAPI() |
| 703 | + |
| 704 | + @app.get("/") |
| 705 | + def read_root() -> int: |
| 706 | + return use_sync_gen_dependency_override() |
| 707 | + |
| 708 | + app.dependency_overrides[sync_gen_dependency_override] = override_sync_gen |
| 709 | + await register_app(app) |
| 710 | + |
| 711 | + with TestClient(app) as client: |
| 712 | + response = client.get("/") |
| 713 | + assert response.status_code == 200 |
| 714 | + assert response.json() == 2 |
| 715 | + |
| 716 | + await cleanup_all_exit_stacks() |
| 717 | + sync_cleanup_mock_override.assert_not_called() |
| 718 | + override_sync_cleanup_mock.assert_called_once() |
| 719 | + |
| 720 | + |
| 721 | +@pytest.mark.asyncio |
| 722 | +async def test_get_injected_obj_with_dependency_override_async_generator( |
| 723 | + clean_exit_stack_manager: None, |
| 724 | +) -> None: |
| 725 | + """Tests that get_injected_obj respects dependency_overrides for async generators.""" |
| 726 | + async_cleanup_mock_override = Mock() |
| 727 | + |
| 728 | + async def async_gen_dependency_override() -> AsyncGenerator[int, None]: |
| 729 | + try: |
| 730 | + yield 1 |
| 731 | + finally: |
| 732 | + async_cleanup_mock_override() |
| 733 | + |
| 734 | + def use_async_gen_dependency_override() -> int: |
| 735 | + return get_injected_obj(async_gen_dependency_override) |
| 736 | + |
| 737 | + override_async_cleanup_mock = Mock() |
| 738 | + |
| 739 | + async def override_async_gen() -> AsyncGenerator[int, None]: |
| 740 | + try: |
| 741 | + yield 2 |
| 742 | + finally: |
| 743 | + override_async_cleanup_mock() |
| 744 | + |
| 745 | + loop_manager.set_loop_strategy( |
| 746 | + "background_thread" |
| 747 | + ) # To avoid affecting the FastAPI app and httpx client event loop |
| 748 | + app = FastAPI() |
| 749 | + |
| 750 | + @app.get("/") |
| 751 | + async def read_root() -> int: |
| 752 | + return use_async_gen_dependency_override() |
| 753 | + |
| 754 | + app.dependency_overrides[async_gen_dependency_override] = override_async_gen |
| 755 | + await register_app(app) |
| 756 | + |
| 757 | + with TestClient(app) as client: |
| 758 | + response = client.get("/") |
| 759 | + assert response.status_code == 200 |
| 760 | + assert response.json() == 2 |
| 761 | + |
| 762 | + await cleanup_all_exit_stacks() |
| 763 | + async_cleanup_mock_override.assert_not_called() |
| 764 | + override_async_cleanup_mock.assert_called_once() |
0 commit comments