Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 164 additions & 2 deletions src/server/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ interface TestServerConfig {
enableJsonResponse?: boolean;
customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise<void>;
eventStore?: EventStore;
onsessionclosed?: (sessionId: string) => void;
}

/**
Expand Down Expand Up @@ -57,7 +58,8 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator:
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: config.sessionIdGenerator,
enableJsonResponse: config.enableJsonResponse ?? false,
eventStore: config.eventStore
eventStore: config.eventStore,
onsessionclosed: config.onsessionclosed
});

await mcpServer.connect(transport);
Expand Down Expand Up @@ -111,7 +113,8 @@ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenera
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: config.sessionIdGenerator,
enableJsonResponse: config.enableJsonResponse ?? false,
eventStore: config.eventStore
eventStore: config.eventStore,
onsessionclosed: config.onsessionclosed
});

await mcpServer.connect(transport);
Expand Down Expand Up @@ -1504,6 +1507,165 @@ describe("StreamableHTTPServerTransport in stateless mode", () => {
});
});

// Test onsessionclosed callback
describe("StreamableHTTPServerTransport onsessionclosed callback", () => {
it("should call onsessionclosed callback when session is closed via DELETE", async () => {
const mockCallback = jest.fn();

// Create server with onsessionclosed callback
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessionclosed: mockCallback,
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to get a session ID
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");
expect(tempSessionId).toBeDefined();

// DELETE the session
const deleteResponse = await fetch(tempUrl, {
method: "DELETE",
headers: {
"mcp-session-id": tempSessionId || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse.status).toBe(200);
expect(mockCallback).toHaveBeenCalledWith(tempSessionId);
expect(mockCallback).toHaveBeenCalledTimes(1);

// Clean up
tempServer.close();
});

it("should not call onsessionclosed callback when not provided", async () => {
// Create server without onsessionclosed callback
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to get a session ID
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
const tempSessionId = initResponse.headers.get("mcp-session-id");

// DELETE the session - should not throw error
const deleteResponse = await fetch(tempUrl, {
method: "DELETE",
headers: {
"mcp-session-id": tempSessionId || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse.status).toBe(200);

// Clean up
tempServer.close();
});

it("should not call onsessionclosed callback for invalid session DELETE", async () => {
const mockCallback = jest.fn();

// Create server with onsessionclosed callback
const result = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessionclosed: mockCallback,
});

const tempServer = result.server;
const tempUrl = result.baseUrl;

// Initialize to get a valid session
await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);

// Try to DELETE with invalid session ID
const deleteResponse = await fetch(tempUrl, {
method: "DELETE",
headers: {
"mcp-session-id": "invalid-session-id",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse.status).toBe(404);
expect(mockCallback).not.toHaveBeenCalled();

// Clean up
tempServer.close();
});

it("should call onsessionclosed callback with correct session ID when multiple sessions exist", async () => {
const mockCallback = jest.fn();

// Create first server
const result1 = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessionclosed: mockCallback,
});

const server1 = result1.server;
const url1 = result1.baseUrl;

// Create second server
const result2 = await createTestServer({
sessionIdGenerator: () => randomUUID(),
onsessionclosed: mockCallback,
});

const server2 = result2.server;
const url2 = result2.baseUrl;

// Initialize both servers
const initResponse1 = await sendPostRequest(url1, TEST_MESSAGES.initialize);
const sessionId1 = initResponse1.headers.get("mcp-session-id");

const initResponse2 = await sendPostRequest(url2, TEST_MESSAGES.initialize);
const sessionId2 = initResponse2.headers.get("mcp-session-id");

expect(sessionId1).toBeDefined();
expect(sessionId2).toBeDefined();
expect(sessionId1).not.toBe(sessionId2);

// DELETE first session
const deleteResponse1 = await fetch(url1, {
method: "DELETE",
headers: {
"mcp-session-id": sessionId1 || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse1.status).toBe(200);
expect(mockCallback).toHaveBeenCalledWith(sessionId1);
expect(mockCallback).toHaveBeenCalledTimes(1);

// DELETE second session
const deleteResponse2 = await fetch(url2, {
method: "DELETE",
headers: {
"mcp-session-id": sessionId2 || "",
"mcp-protocol-version": "2025-03-26",
},
});

expect(deleteResponse2.status).toBe(200);
expect(mockCallback).toHaveBeenCalledWith(sessionId2);
expect(mockCallback).toHaveBeenCalledTimes(2);

// Clean up
server1.close();
server2.close();
});
});

// Test DNS rebinding protection
describe("StreamableHTTPServerTransport DNS rebinding protection", () => {
let server: Server;
Expand Down
15 changes: 15 additions & 0 deletions src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ export interface StreamableHTTPServerTransportOptions {
*/
onsessioninitialized?: (sessionId: string) => void;

/**
* A callback for session close events
* This is called when the server closes a session due to a DELETE request.
* Useful in cases when you need to clean up resources associated with the session.
* Note that this is different from the transport closing, if you are handling
* HTTP requests from multiple nodes you might want to close each
* StreamableHTTPServerTransport after a request is completed while still keeping the
* session open/running.
* @param sessionId The session ID that was closed
*/
onsessionclosed?: (sessionId: string) => void;

/**
* If true, the server will return JSON responses instead of starting an SSE stream.
* This can be useful for simple request/response scenarios without streaming.
Expand Down Expand Up @@ -127,6 +139,7 @@ export class StreamableHTTPServerTransport implements Transport {
private _standaloneSseStreamId: string = '_GET_stream';
private _eventStore?: EventStore;
private _onsessioninitialized?: (sessionId: string) => void;
private _onsessionclosed?: (sessionId: string) => void;
private _allowedHosts?: string[];
private _allowedOrigins?: string[];
private _enableDnsRebindingProtection: boolean;
Expand All @@ -141,6 +154,7 @@ export class StreamableHTTPServerTransport implements Transport {
this._enableJsonResponse = options.enableJsonResponse ?? false;
this._eventStore = options.eventStore;
this._onsessioninitialized = options.onsessioninitialized;
this._onsessionclosed = options.onsessionclosed;
this._allowedHosts = options.allowedHosts;
this._allowedOrigins = options.allowedOrigins;
this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false;
Expand Down Expand Up @@ -538,6 +552,7 @@ export class StreamableHTTPServerTransport implements Transport {
if (!this.validateProtocolVersion(req, res)) {
return;
}
this._onsessionclosed?.(this.sessionId!);
await this.close();
res.writeHead(200).end();
}
Expand Down