diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 502435ead..e54bea017 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -29,6 +29,7 @@ interface TestServerConfig { enableJsonResponse?: boolean; customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise; eventStore?: EventStore; + onsessionclosed?: (sessionId: string) => void; } /** @@ -57,7 +58,8 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator: const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, - eventStore: config.eventStore + eventStore: config.eventStore, + onsessionclosed: config.onsessionclosed }); await mcpServer.connect(transport); @@ -111,7 +113,8 @@ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenera const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator, enableJsonResponse: config.enableJsonResponse ?? false, - eventStore: config.eventStore + eventStore: config.eventStore, + onsessionclosed: config.onsessionclosed }); await mcpServer.connect(transport); @@ -1504,6 +1507,165 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { }); }); +// Test onsessionclosed callback +describe("StreamableHTTPServerTransport onsessionclosed callback", () => { + it("should call onsessionclosed callback when session is closed via DELETE", async () => { + const mockCallback = jest.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(tempSessionId); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // Clean up + tempServer.close(); + }); + + it("should not call onsessionclosed callback when not provided", async () => { + // Create server without onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get("mcp-session-id"); + + // DELETE the session - should not throw error + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": tempSessionId || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(200); + + // Clean up + tempServer.close(); + }); + + it("should not call onsessionclosed callback for invalid session DELETE", async () => { + const mockCallback = jest.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a valid session + await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + + // Try to DELETE with invalid session ID + const deleteResponse = await fetch(tempUrl, { + method: "DELETE", + headers: { + "mcp-session-id": "invalid-session-id", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse.status).toBe(404); + expect(mockCallback).not.toHaveBeenCalled(); + + // Clean up + tempServer.close(); + }); + + it("should call onsessionclosed callback with correct session ID when multiple sessions exist", async () => { + const mockCallback = jest.fn(); + + // Create first server + const result1 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const server1 = result1.server; + const url1 = result1.baseUrl; + + // Create second server + const result2 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback, + }); + + const server2 = result2.server; + const url2 = result2.baseUrl; + + // Initialize both servers + const initResponse1 = await sendPostRequest(url1, TEST_MESSAGES.initialize); + const sessionId1 = initResponse1.headers.get("mcp-session-id"); + + const initResponse2 = await sendPostRequest(url2, TEST_MESSAGES.initialize); + const sessionId2 = initResponse2.headers.get("mcp-session-id"); + + expect(sessionId1).toBeDefined(); + expect(sessionId2).toBeDefined(); + expect(sessionId1).not.toBe(sessionId2); + + // DELETE first session + const deleteResponse1 = await fetch(url1, { + method: "DELETE", + headers: { + "mcp-session-id": sessionId1 || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse1.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId1); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // DELETE second session + const deleteResponse2 = await fetch(url2, { + method: "DELETE", + headers: { + "mcp-session-id": sessionId2 || "", + "mcp-protocol-version": "2025-03-26", + }, + }); + + expect(deleteResponse2.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId2); + expect(mockCallback).toHaveBeenCalledTimes(2); + + // Clean up + server1.close(); + server2.close(); + }); +}); + // Test DNS rebinding protection describe("StreamableHTTPServerTransport DNS rebinding protection", () => { let server: Server; diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 022d1a474..37164c869 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -49,6 +49,18 @@ export interface StreamableHTTPServerTransportOptions { */ onsessioninitialized?: (sessionId: string) => void; + /** + * A callback for session close events + * This is called when the server closes a session due to a DELETE request. + * Useful in cases when you need to clean up resources associated with the session. + * Note that this is different from the transport closing, if you are handling + * HTTP requests from multiple nodes you might want to close each + * StreamableHTTPServerTransport after a request is completed while still keeping the + * session open/running. + * @param sessionId The session ID that was closed + */ + onsessionclosed?: (sessionId: string) => void; + /** * If true, the server will return JSON responses instead of starting an SSE stream. * This can be useful for simple request/response scenarios without streaming. @@ -127,6 +139,7 @@ export class StreamableHTTPServerTransport implements Transport { private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; private _onsessioninitialized?: (sessionId: string) => void; + private _onsessionclosed?: (sessionId: string) => void; private _allowedHosts?: string[]; private _allowedOrigins?: string[]; private _enableDnsRebindingProtection: boolean; @@ -141,6 +154,7 @@ export class StreamableHTTPServerTransport implements Transport { this._enableJsonResponse = options.enableJsonResponse ?? false; this._eventStore = options.eventStore; this._onsessioninitialized = options.onsessioninitialized; + this._onsessionclosed = options.onsessionclosed; this._allowedHosts = options.allowedHosts; this._allowedOrigins = options.allowedOrigins; this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; @@ -538,6 +552,7 @@ export class StreamableHTTPServerTransport implements Transport { if (!this.validateProtocolVersion(req, res)) { return; } + this._onsessionclosed?.(this.sessionId!); await this.close(); res.writeHead(200).end(); }