Skip to content
44 changes: 18 additions & 26 deletions tests/client/test_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def store():


@pytest.mark.asyncio
async def test_auth_interceptor_skips_when_no_agent_card(store):
"""
Tests that the AuthInterceptor does not modify the request when no AgentCard is provided.
"""
async def test_auth_interceptor_skips_when_no_agent_card(
store: InMemoryContextCredentialStore,
) -> None:
"""Tests that the AuthInterceptor does not modify the request when no AgentCard is provided."""
request_payload = {'foo': 'bar'}
http_kwargs = {'fizz': 'buzz'}
auth_interceptor = AuthInterceptor(credential_service=store)
Expand All @@ -126,9 +126,10 @@ async def test_auth_interceptor_skips_when_no_agent_card(store):


@pytest.mark.asyncio
async def test_in_memory_context_credential_store(store):
"""
Verifies that InMemoryContextCredentialStore correctly stores and retrieves
async def test_in_memory_context_credential_store(
store: InMemoryContextCredentialStore,
) -> None:
"""Verifies that InMemoryContextCredentialStore correctly stores and retrieves
credentials based on the session ID in the client context.
"""
session_id = 'session-id'
Expand Down Expand Up @@ -163,11 +164,8 @@ async def test_in_memory_context_credential_store(store):

@pytest.mark.asyncio
@respx.mock
async def test_client_with_simple_interceptor():
"""
Ensures that a custom HeaderInterceptor correctly injects a static header
into outbound HTTP requests from the A2AClient.
"""
async def test_client_with_simple_interceptor() -> None:
"""Ensures that a custom HeaderInterceptor correctly injects a static header into outbound HTTP requests from the A2AClient."""
url = 'http://agent.com/rpc'
interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123')
card = AgentCard(
Expand Down Expand Up @@ -196,9 +194,7 @@ async def test_client_with_simple_interceptor():

@dataclass
class AuthTestCase:
"""
Represents a test scenario for verifying authentication behavior in AuthInterceptor.
"""
"""Represents a test scenario for verifying authentication behavior in AuthInterceptor."""

url: str
"""The endpoint URL of the agent to which the request is sent."""
Expand Down Expand Up @@ -284,11 +280,10 @@ class AuthTestCase:
[api_key_test_case, oauth2_test_case, oidc_test_case, bearer_test_case],
)
@respx.mock
async def test_auth_interceptor_variants(test_case, store):
"""
Parametrized test verifying that AuthInterceptor correctly attaches credentials
based on the defined security scheme in the AgentCard.
"""
async def test_auth_interceptor_variants(
test_case: AuthTestCase, store: InMemoryContextCredentialStore
) -> None:
"""Parametrized test verifying that AuthInterceptor correctly attaches credentials based on the defined security scheme in the AgentCard."""
await store.set_credentials(
test_case.session_id, test_case.scheme_name, test_case.credential
)
Expand Down Expand Up @@ -329,12 +324,9 @@ async def test_auth_interceptor_variants(test_case, store):

@pytest.mark.asyncio
async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
store,
):
"""
Tests that AuthInterceptor skips a scheme if it's listed in security requirements
but not defined in security_schemes.
"""
store: InMemoryContextCredentialStore,
) -> None:
"""Tests that AuthInterceptor skips a scheme if it's listed in security requirements but not defined in security_schemes."""
scheme_name = 'missing'
session_id = 'session-id'
credential = 'dummy-token'
Expand Down
16 changes: 9 additions & 7 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@


@pytest.fixture
def mock_transport():
def mock_transport() -> AsyncMock:
return AsyncMock(spec=ClientTransport)


@pytest.fixture
def sample_agent_card():
def sample_agent_card() -> AgentCard:
return AgentCard(
name='Test Agent',
description='An agent for testing',
Expand All @@ -38,7 +38,7 @@ def sample_agent_card():


@pytest.fixture
def sample_message():
def sample_message() -> Message:
return Message(
role=Role.user,
message_id='msg-1',
Expand All @@ -47,7 +47,9 @@ def sample_message():


@pytest.fixture
def base_client(sample_agent_card, mock_transport):
def base_client(
sample_agent_card: AgentCard, mock_transport: AsyncMock
) -> BaseClient:
config = ClientConfig(streaming=True)
return BaseClient(
card=sample_agent_card,
Expand All @@ -61,7 +63,7 @@ def base_client(sample_agent_card, mock_transport):
@pytest.mark.asyncio
async def test_send_message_streaming(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
) -> None:
async def create_stream(*args, **kwargs):
yield Task(
id='task-123',
Expand All @@ -82,7 +84,7 @@ async def create_stream(*args, **kwargs):
@pytest.mark.asyncio
async def test_send_message_non_streaming(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
) -> None:
base_client._config.streaming = False
mock_transport.send_message.return_value = Task(
id='task-456',
Expand All @@ -101,7 +103,7 @@ async def test_send_message_non_streaming(
@pytest.mark.asyncio
async def test_send_message_non_streaming_agent_capability_false(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
) -> None:
base_client._card.capabilities.streaming = False
mock_transport.send_message.return_value = Task(
id='task-789',
Expand Down
32 changes: 18 additions & 14 deletions tests/client/test_client_task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@


@pytest.fixture
def task_manager():
def task_manager() -> ClientTaskManager:
return ClientTaskManager()


@pytest.fixture
def sample_task():
def sample_task() -> Task:
return Task(
id='task123',
context_id='context456',
Expand All @@ -38,29 +38,31 @@ def sample_task():


@pytest.fixture
def sample_message():
def sample_message() -> Message:
return Message(
message_id='msg1',
role=Role.user,
parts=[Part(root=TextPart(text='Hello'))],
)


def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager):
def test_get_task_no_task_id_returns_none(
task_manager: ClientTaskManager,
) -> None:
assert task_manager.get_task() is None


def test_get_task_or_raise_no_task_raises_error(
task_manager: ClientTaskManager,
):
) -> None:
with pytest.raises(A2AClientInvalidStateError, match='no current Task'):
task_manager.get_task_or_raise()


@pytest.mark.asyncio
async def test_save_task_event_with_task(
task_manager: ClientTaskManager, sample_task: Task
):
) -> None:
await task_manager.save_task_event(sample_task)
assert task_manager.get_task() == sample_task
assert task_manager._task_id == sample_task.id
Expand All @@ -70,7 +72,7 @@ async def test_save_task_event_with_task(
@pytest.mark.asyncio
async def test_save_task_event_with_task_already_set_raises_error(
task_manager: ClientTaskManager, sample_task: Task
):
) -> None:
await task_manager.save_task_event(sample_task)
with pytest.raises(
A2AClientInvalidArgsError,
Expand All @@ -82,7 +84,7 @@ async def test_save_task_event_with_task_already_set_raises_error(
@pytest.mark.asyncio
async def test_save_task_event_with_status_update(
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
):
) -> None:
await task_manager.save_task_event(sample_task)
status_update = TaskStatusUpdateEvent(
task_id=sample_task.id,
Expand All @@ -98,7 +100,7 @@ async def test_save_task_event_with_status_update(
@pytest.mark.asyncio
async def test_save_task_event_with_artifact_update(
task_manager: ClientTaskManager, sample_task: Task
):
) -> None:
await task_manager.save_task_event(sample_task)
artifact = Artifact(
artifact_id='art1', parts=[Part(root=TextPart(text='artifact content'))]
Expand All @@ -119,7 +121,7 @@ async def test_save_task_event_with_artifact_update(
@pytest.mark.asyncio
async def test_save_task_event_creates_task_if_not_exists(
task_manager: ClientTaskManager,
):
) -> None:
status_update = TaskStatusUpdateEvent(
task_id='new_task',
context_id='new_context',
Expand All @@ -135,7 +137,7 @@ async def test_save_task_event_creates_task_if_not_exists(
@pytest.mark.asyncio
async def test_process_with_task_event(
task_manager: ClientTaskManager, sample_task: Task
):
) -> None:
with patch.object(
task_manager, 'save_task_event', new_callable=AsyncMock
) as mock_save:
Expand All @@ -144,7 +146,9 @@ async def test_process_with_task_event(


@pytest.mark.asyncio
async def test_process_with_non_task_event(task_manager: ClientTaskManager):
async def test_process_with_non_task_event(
task_manager: ClientTaskManager,
) -> None:
with patch.object(
task_manager, 'save_task_event', new_callable=Mock
) as mock_save:
Expand All @@ -155,14 +159,14 @@ async def test_process_with_non_task_event(task_manager: ClientTaskManager):

def test_update_with_message(
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
):
) -> None:
updated_task = task_manager.update_with_message(sample_message, sample_task)
assert updated_task.history == [sample_message]


def test_update_with_message_moves_status_message(
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
):
) -> None:
status_message = Message(
message_id='status_msg',
role=Role.agent,
Expand Down
Loading
Loading