| 
1 | 1 | from collections.abc import AsyncGenerator, Generator  | 
2 | 2 | from typing import Annotated, Any  | 
 | 3 | +from unittest.mock import Mock  | 
3 | 4 | 
 
  | 
4 | 5 | import pytest  | 
5 |  | -from fastapi import Depends  | 
 | 6 | +from fastapi import Depends, FastAPI  | 
 | 7 | +from fastapi.testclient import TestClient  | 
6 | 8 | 
 
  | 
7 |  | -from src.fastapi_injectable.concurrency import run_coroutine_sync  | 
 | 9 | +from src.fastapi_injectable import register_app  | 
 | 10 | +from src.fastapi_injectable.concurrency import loop_manager, run_coroutine_sync  | 
8 | 11 | from src.fastapi_injectable.decorator import injectable  | 
9 |  | -from src.fastapi_injectable.util import cleanup_all_exit_stacks, cleanup_exit_stack_of_func, get_injected_obj  | 
 | 12 | +from src.fastapi_injectable.util import (  | 
 | 13 | +    cleanup_all_exit_stacks,  | 
 | 14 | +    cleanup_exit_stack_of_func,  | 
 | 15 | +    get_injected_obj,  | 
 | 16 | +)  | 
10 | 17 | 
 
  | 
11 | 18 | 
 
  | 
12 | 19 | @pytest.fixture  | 
@@ -604,3 +611,154 @@ def get_country(  | 
604 | 611 | 
 
  | 
605 | 612 |     country: Country = get_country(basic_str="basic_str", basic_int=1, basic_bool=True, basic_dict={"key": "value"})  # type: ignore  # noqa: PGH003  | 
606 | 613 |     assert country is not None  | 
 | 614 | + | 
 | 615 | + | 
 | 616 | +def test_get_injected_obj_with_dependency_override_sync(clean_exit_stack_manager: None) -> None:  | 
 | 617 | +    """Tests that get_injected_obj respects dependency_overrides for sync dependencies."""  | 
 | 618 | + | 
 | 619 | +    def sync_dependency_override() -> int:  | 
 | 620 | +        return 1  | 
 | 621 | + | 
 | 622 | +    def use_sync_dependency_override() -> int:  | 
 | 623 | +        return get_injected_obj(sync_dependency_override)  | 
 | 624 | + | 
 | 625 | +    loop_manager.set_loop_strategy(  | 
 | 626 | +        "background_thread"  | 
 | 627 | +    )  # To avoid affecting the FastAPI app and httpx client event loop  | 
 | 628 | +    app = FastAPI()  | 
 | 629 | + | 
 | 630 | +    @app.get("/")  | 
 | 631 | +    def read_root() -> int:  | 
 | 632 | +        return use_sync_dependency_override()  | 
 | 633 | + | 
 | 634 | +    mock_dependency = Mock(return_value=2)  | 
 | 635 | +    app.dependency_overrides[sync_dependency_override] = lambda: mock_dependency()  | 
 | 636 | +    run_coroutine_sync(register_app(app))  | 
 | 637 | + | 
 | 638 | +    client = TestClient(app)  | 
 | 639 | +    response = client.get("/")  | 
 | 640 | +    assert response.status_code == 200  | 
 | 641 | +    assert response.json() == 2  | 
 | 642 | +    mock_dependency.assert_called_once()  | 
 | 643 | + | 
 | 644 | + | 
 | 645 | +@pytest.mark.asyncio  | 
 | 646 | +async def test_get_injected_obj_with_dependency_override_async(clean_exit_stack_manager: None) -> None:  | 
 | 647 | +    """Tests that get_injected_obj respects dependency_overrides for async dependencies."""  | 
 | 648 | + | 
 | 649 | +    async def async_dependency_override() -> int:  | 
 | 650 | +        return 1  | 
 | 651 | + | 
 | 652 | +    def use_async_dependency_override() -> int:  | 
 | 653 | +        return get_injected_obj(async_dependency_override)  | 
 | 654 | + | 
 | 655 | +    loop_manager.set_loop_strategy(  | 
 | 656 | +        "background_thread"  | 
 | 657 | +    )  # To avoid affecting the FastAPI app and httpx client event loop  | 
 | 658 | +    app = FastAPI()  | 
 | 659 | + | 
 | 660 | +    @app.get("/")  | 
 | 661 | +    async def read_root() -> int:  | 
 | 662 | +        return use_async_dependency_override()  | 
 | 663 | + | 
 | 664 | +    mock_dependency = Mock(return_value=2)  | 
 | 665 | +    app.dependency_overrides[async_dependency_override] = lambda: mock_dependency()  | 
 | 666 | +    await register_app(app)  | 
 | 667 | + | 
 | 668 | +    client = TestClient(app)  | 
 | 669 | +    response = client.get("/")  | 
 | 670 | +    assert response.status_code == 200  | 
 | 671 | +    assert response.json() == 2  | 
 | 672 | +    mock_dependency.assert_called_once()  | 
 | 673 | + | 
 | 674 | + | 
 | 675 | +@pytest.mark.asyncio  | 
 | 676 | +async def test_get_injected_obj_with_dependency_override_sync_generator(  | 
 | 677 | +    clean_exit_stack_manager: None,  | 
 | 678 | +) -> None:  | 
 | 679 | +    """Tests that get_injected_obj respects dependency_overrides for sync generators."""  | 
 | 680 | +    sync_cleanup_mock_override = Mock()  | 
 | 681 | + | 
 | 682 | +    def sync_gen_dependency_override() -> Generator[int, None, None]:  | 
 | 683 | +        try:  | 
 | 684 | +            yield 1  | 
 | 685 | +        finally:  | 
 | 686 | +            sync_cleanup_mock_override()  | 
 | 687 | + | 
 | 688 | +    def use_sync_gen_dependency_override() -> int:  | 
 | 689 | +        return get_injected_obj(sync_gen_dependency_override)  | 
 | 690 | + | 
 | 691 | +    override_sync_cleanup_mock = Mock()  | 
 | 692 | + | 
 | 693 | +    def override_sync_gen() -> Generator[int, None, None]:  | 
 | 694 | +        try:  | 
 | 695 | +            yield 2  | 
 | 696 | +        finally:  | 
 | 697 | +            override_sync_cleanup_mock()  | 
 | 698 | + | 
 | 699 | +    loop_manager.set_loop_strategy(  | 
 | 700 | +        "background_thread"  | 
 | 701 | +    )  # To avoid affecting the FastAPI app and httpx client event loop  | 
 | 702 | +    app = FastAPI()  | 
 | 703 | + | 
 | 704 | +    @app.get("/")  | 
 | 705 | +    def read_root() -> int:  | 
 | 706 | +        return use_sync_gen_dependency_override()  | 
 | 707 | + | 
 | 708 | +    app.dependency_overrides[sync_gen_dependency_override] = override_sync_gen  | 
 | 709 | +    await register_app(app)  | 
 | 710 | + | 
 | 711 | +    with TestClient(app) as client:  | 
 | 712 | +        response = client.get("/")  | 
 | 713 | +        assert response.status_code == 200  | 
 | 714 | +        assert response.json() == 2  | 
 | 715 | + | 
 | 716 | +    await cleanup_all_exit_stacks()  | 
 | 717 | +    sync_cleanup_mock_override.assert_not_called()  | 
 | 718 | +    override_sync_cleanup_mock.assert_called_once()  | 
 | 719 | + | 
 | 720 | + | 
 | 721 | +@pytest.mark.asyncio  | 
 | 722 | +async def test_get_injected_obj_with_dependency_override_async_generator(  | 
 | 723 | +    clean_exit_stack_manager: None,  | 
 | 724 | +) -> None:  | 
 | 725 | +    """Tests that get_injected_obj respects dependency_overrides for async generators."""  | 
 | 726 | +    async_cleanup_mock_override = Mock()  | 
 | 727 | + | 
 | 728 | +    async def async_gen_dependency_override() -> AsyncGenerator[int, None]:  | 
 | 729 | +        try:  | 
 | 730 | +            yield 1  | 
 | 731 | +        finally:  | 
 | 732 | +            async_cleanup_mock_override()  | 
 | 733 | + | 
 | 734 | +    def use_async_gen_dependency_override() -> int:  | 
 | 735 | +        return get_injected_obj(async_gen_dependency_override)  | 
 | 736 | + | 
 | 737 | +    override_async_cleanup_mock = Mock()  | 
 | 738 | + | 
 | 739 | +    async def override_async_gen() -> AsyncGenerator[int, None]:  | 
 | 740 | +        try:  | 
 | 741 | +            yield 2  | 
 | 742 | +        finally:  | 
 | 743 | +            override_async_cleanup_mock()  | 
 | 744 | + | 
 | 745 | +    loop_manager.set_loop_strategy(  | 
 | 746 | +        "background_thread"  | 
 | 747 | +    )  # To avoid affecting the FastAPI app and httpx client event loop  | 
 | 748 | +    app = FastAPI()  | 
 | 749 | + | 
 | 750 | +    @app.get("/")  | 
 | 751 | +    async def read_root() -> int:  | 
 | 752 | +        return use_async_gen_dependency_override()  | 
 | 753 | + | 
 | 754 | +    app.dependency_overrides[async_gen_dependency_override] = override_async_gen  | 
 | 755 | +    await register_app(app)  | 
 | 756 | + | 
 | 757 | +    with TestClient(app) as client:  | 
 | 758 | +        response = client.get("/")  | 
 | 759 | +        assert response.status_code == 200  | 
 | 760 | +        assert response.json() == 2  | 
 | 761 | + | 
 | 762 | +    await cleanup_all_exit_stacks()  | 
 | 763 | +    async_cleanup_mock_override.assert_not_called()  | 
 | 764 | +    override_async_cleanup_mock.assert_called_once()  | 
0 commit comments