11"""Test that cancelled requests don't cause double responses."""
22
3- import asyncio
4- from unittest .mock import MagicMock
5-
3+ import anyio
64import pytest
75
86import mcp .types as types
97from mcp .server .lowlevel .server import Server
10- from mcp .types import PingRequest
11-
12-
13- # Shared mock class
14- class MockRequestResponder :
15- def __init__ (self ):
16- self .request_id = "test-123"
17- self ._responded = False
18- self .request_meta = {}
19- self .message_metadata = None
20-
21- async def send (self , response ):
22- if self ._responded :
23- raise AssertionError (f"Request { self .request_id } already responded to" )
24- self ._responded = True
25-
26- async def respond (self , response ):
27- await self .send (response )
28-
29- def cancel (self ):
30- """Simulate the cancel() method sending an error response."""
31- asyncio .create_task (self .send (types .ErrorData (code = - 32800 , message = "Request cancelled" )))
8+ from mcp .shared .exceptions import McpError
9+ from mcp .shared .memory import create_connected_server_and_client_session
10+ from mcp .types import (
11+ CallToolRequest ,
12+ CallToolRequestParams ,
13+ CallToolResult ,
14+ CancelledNotification ,
15+ CancelledNotificationParams ,
16+ ClientNotification ,
17+ ClientRequest ,
18+ Tool ,
19+ )
3220
3321
3422@pytest .mark .anyio
3523async def test_cancelled_request_no_double_response ():
3624 """Verify server handles cancelled requests without double response."""
3725
38- # Create a server instance
26+ # Create server with a slow tool
3927 server = Server ("test-server" )
4028
41- # Track if multiple responses are attempted
42- response_count = 0
43-
44- # Override the send method to track calls
45- mock_message = MockRequestResponder ()
46- original_send = mock_message .send
47-
48- async def tracked_send (response ):
49- nonlocal response_count
50- response_count += 1
51- await original_send (response )
52-
53- mock_message .send = tracked_send
54-
55- # Create a slow handler that will be cancelled
56- async def slow_handler (req ):
57- await asyncio .sleep (10 )
58- return types .ServerResult (types .EmptyResult ())
59-
60- # Use PingRequest as it's a valid request type
61- server .request_handlers [types .PingRequest ] = slow_handler
62-
63- # Create mock message and session
64- mock_req = PingRequest (method = "ping" )
65- mock_session = MagicMock ()
66- mock_context = None
67-
68- # Start the request
69- handle_task = asyncio .create_task (
70- server ._handle_request (mock_message , mock_req , mock_session , mock_context , raise_exceptions = False ) # type: ignore
71- )
72-
73- # Give it time to start
74- await asyncio .sleep (0.1 )
75-
76- # Simulate cancellation
77- mock_message .cancel ()
78- handle_task .cancel ()
79-
80- # Wait for cancellation to propagate
81- try :
82- await handle_task
83- except asyncio .CancelledError :
84- pass
85-
86- # Give time for any duplicate response attempts
87- await asyncio .sleep (0.1 )
88-
89- # Should only have one response (from cancel())
90- assert response_count == 1 , f"Expected 1 response, got { response_count } "
29+ # Track when tool is called
30+ ev_tool_called = anyio .Event ()
31+ request_id = None
32+
33+ @server .list_tools ()
34+ async def handle_list_tools () -> list [Tool ]:
35+ return [
36+ Tool (
37+ name = "slow_tool" ,
38+ description = "A slow tool for testing cancellation" ,
39+ inputSchema = {},
40+ )
41+ ]
42+
43+ @server .call_tool ()
44+ async def handle_call_tool (name : str , arguments : dict | None ) -> list :
45+ nonlocal request_id
46+ if name == "slow_tool" :
47+ request_id = server .request_context .request_id
48+ ev_tool_called .set ()
49+ await anyio .sleep (10 ) # Long running operation
50+ return [types .TextContent (type = "text" , text = "Tool called" )]
51+ raise ValueError (f"Unknown tool: { name } " )
52+
53+ # Connect client to server
54+ async with create_connected_server_and_client_session (server ) as client :
55+ # Start the slow tool call in a separate task
56+ async def make_request ():
57+ try :
58+ await client .send_request (
59+ ClientRequest (
60+ CallToolRequest (
61+ method = "tools/call" ,
62+ params = CallToolRequestParams (name = "slow_tool" , arguments = {}),
63+ )
64+ ),
65+ CallToolResult ,
66+ )
67+ pytest .fail ("Request should have been cancelled" )
68+ except McpError as e :
69+ # Expected - request was cancelled
70+ assert e .error .code == 0 # Request cancelled error code
71+
72+ # Start the request
73+ request_task = anyio .create_task_group ()
74+ async with request_task :
75+ request_task .start_soon (make_request )
76+
77+ # Wait for tool to start executing
78+ await ev_tool_called .wait ()
79+
80+ # Send cancellation notification
81+ assert request_id is not None
82+ await client .send_notification (
83+ ClientNotification (
84+ CancelledNotification (
85+ method = "notifications/cancelled" ,
86+ params = CancelledNotificationParams (
87+ requestId = request_id ,
88+ reason = "Test cancellation" ,
89+ ),
90+ )
91+ )
92+ )
93+
94+ # The request should be cancelled and raise McpError
9195
9296
9397@pytest .mark .anyio
@@ -96,43 +100,87 @@ async def test_server_remains_functional_after_cancel():
96100
97101 server = Server ("test-server" )
98102
99- # Add handlers
100- async def slow_handler (req ):
101- await asyncio .sleep (5 )
102- return types .ServerResult (types .EmptyResult ())
103-
104- async def fast_handler (req ):
105- return types .ServerResult (types .EmptyResult ())
106-
107- # Override ping handler for our test
108- server .request_handlers [types .PingRequest ] = slow_handler
109-
110- # First request (will be cancelled)
111- mock_message1 = MockRequestResponder ()
112- mock_req1 = PingRequest (method = "ping" )
113-
114- handle_task = asyncio .create_task (
115- server ._handle_request (mock_message1 , mock_req1 , MagicMock (), None , raise_exceptions = False ) # type: ignore
116- )
117-
118- await asyncio .sleep (0.1 )
119- mock_message1 .cancel ()
120- handle_task .cancel ()
121-
122- try :
123- await handle_task
124- except asyncio .CancelledError :
125- pass
126-
127- # Change handler to fast one
128- server .request_handlers [types .PingRequest ] = fast_handler
129-
130- # Second request (should work normally)
131- mock_message2 = MockRequestResponder ()
132- mock_req2 = PingRequest (method = "ping" )
133-
134- # This should complete successfully
135- await server ._handle_request (mock_message2 , mock_req2 , MagicMock (), None , raise_exceptions = False ) # type: ignore
136-
137- # Server handled the second request successfully
138- assert mock_message2 ._responded
103+ # Track tool calls
104+ call_count = 0
105+ ev_first_call = anyio .Event ()
106+ first_request_id = None
107+
108+ @server .list_tools ()
109+ async def handle_list_tools () -> list [Tool ]:
110+ return [
111+ Tool (
112+ name = "test_tool" ,
113+ description = "Tool for testing" ,
114+ inputSchema = {},
115+ )
116+ ]
117+
118+ @server .call_tool ()
119+ async def handle_call_tool (name : str , arguments : dict | None ) -> list :
120+ nonlocal call_count , first_request_id
121+ if name == "test_tool" :
122+ call_count += 1
123+ if call_count == 1 :
124+ first_request_id = server .request_context .request_id
125+ ev_first_call .set ()
126+ await anyio .sleep (5 ) # First call is slow
127+ return [types .TextContent (type = "text" , text = f"Call number: { call_count } " )]
128+ raise ValueError (f"Unknown tool: { name } " )
129+
130+ async with create_connected_server_and_client_session (server ) as client :
131+ # First request (will be cancelled)
132+ async def first_request ():
133+ try :
134+ await client .send_request (
135+ ClientRequest (
136+ CallToolRequest (
137+ method = "tools/call" ,
138+ params = CallToolRequestParams (name = "test_tool" , arguments = {}),
139+ )
140+ ),
141+ CallToolResult ,
142+ )
143+ pytest .fail ("First request should have been cancelled" )
144+ except McpError :
145+ pass # Expected
146+
147+ # Start first request
148+ async with anyio .create_task_group () as tg :
149+ tg .start_soon (first_request )
150+
151+ # Wait for it to start
152+ await ev_first_call .wait ()
153+
154+ # Cancel it
155+ assert first_request_id is not None
156+ await client .send_notification (
157+ ClientNotification (
158+ CancelledNotification (
159+ method = "notifications/cancelled" ,
160+ params = CancelledNotificationParams (
161+ requestId = first_request_id ,
162+ reason = "Testing server recovery" ,
163+ ),
164+ )
165+ )
166+ )
167+
168+ # Second request (should work normally)
169+ result = await client .send_request (
170+ ClientRequest (
171+ CallToolRequest (
172+ method = "tools/call" ,
173+ params = CallToolRequestParams (name = "test_tool" , arguments = {}),
174+ )
175+ ),
176+ CallToolResult ,
177+ )
178+
179+ # Verify second request completed successfully
180+ assert len (result .content ) == 1
181+ # Type narrowing for pyright
182+ content = result .content [0 ]
183+ assert content .type == "text"
184+ assert isinstance (content , types .TextContent )
185+ assert content .text == "Call number: 2"
186+ assert call_count == 2
0 commit comments