Skip to content
Merged
6 changes: 5 additions & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ export class Client<

override async connect(transport: Transport, options?: RequestOptions): Promise<void> {
await super.connect(transport);

// When transport sessionId is already set this means we are trying to reconnect.
// In this case we don't need to initialize again.
if (transport.sessionId !== undefined) {
return;
}
try {
const result = await this.request(
{
Expand Down
10 changes: 5 additions & 5 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => {
// We expect the 405 error to be caught and handled gracefully
// This should not throw an error that breaks the transport
await transport.start();
await expect(transport["_startOrAuthStandaloneSSE"]({})).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed");
await expect(transport["_startOrAuthSse"]({})).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed");
// Check that GET was attempted
expect(global.fetch).toHaveBeenCalledWith(
expect.anything(),
Expand Down Expand Up @@ -208,7 +208,7 @@ describe("StreamableHTTPClientTransport", () => {
transport.onmessage = messageSpy;

await transport.start();
await transport["_startOrAuthStandaloneSSE"]({});
await transport["_startOrAuthSse"]({});

// Give time for the SSE event to be processed
await new Promise(resolve => setTimeout(resolve, 50));
Expand Down Expand Up @@ -313,9 +313,9 @@ describe("StreamableHTTPClientTransport", () => {
await transport.start();
// Type assertion to access private method
const transportWithPrivateMethods = transport as unknown as {
_startOrAuthStandaloneSSE: (options: { lastEventId?: string }) => Promise<void>
_startOrAuthSse: (options: { lastEventId?: string }) => Promise<void>
};
await transportWithPrivateMethods._startOrAuthStandaloneSSE({ lastEventId: "test-event-id" });
await transportWithPrivateMethods._startOrAuthSse({ lastEventId: "test-event-id" });

// Verify fetch was called with the lastEventId header
expect(fetchSpy).toHaveBeenCalled();
Expand Down Expand Up @@ -382,7 +382,7 @@ describe("StreamableHTTPClientTransport", () => {

await transport.start();

await transport["_startOrAuthStandaloneSSE"]({});
await transport["_startOrAuthSse"]({});
expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue");

requestInit.headers["X-Custom-Header"] = "SecondCustomValue";
Expand Down
61 changes: 49 additions & 12 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Transport } from "../shared/transport.js";
import { isJSONRPCNotification, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { EventSourceParserStream } from "eventsource-parser/stream";

Expand Down Expand Up @@ -28,6 +28,14 @@ export interface StartSSEOptions {
* The ID of the last received event, used for resuming a disconnected stream
*/
lastEventId?: string;
/**
* The callback function that is invoked when the last event ID changes
*/
onLastEventIdUpdate?: (event: string) => void
/**
* When reconnecting to a long-running SSE stream, we need to make sure that message id matches
*/
replayMessageId?: string | number;
}

/**
Expand Down Expand Up @@ -88,6 +96,12 @@ export type StreamableHTTPClientTransportOptions = {
* Options to configure the reconnection behavior.
*/
reconnectionOptions?: StreamableHTTPReconnectionOptions;

/**
* Session ID for the connection. This is used to identify the session on the server.
* When not provided and connecting to a server that supports session IDs, the server will generate a new session ID.
*/
sessionId?: string;
};

/**
Expand All @@ -114,6 +128,7 @@ export class StreamableHTTPClientTransport implements Transport {
this._url = url;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._sessionId = opts?.sessionId;
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
}

Expand All @@ -134,7 +149,7 @@ export class StreamableHTTPClientTransport implements Transport {
throw new UnauthorizedError();
}

return await this._startOrAuthStandaloneSSE({ lastEventId: undefined });
return await this._startOrAuthSse({ lastEventId: undefined });
}

private async _commonHeaders(): Promise<Headers> {
Expand All @@ -156,7 +171,7 @@ export class StreamableHTTPClientTransport implements Transport {
}


private async _startOrAuthStandaloneSSE(options: StartSSEOptions): Promise<void> {
private async _startOrAuthSse(options: StartSSEOptions): Promise<void> {
const { lastEventId } = options;
try {
// Try to open an initial SSE stream with GET to listen for server messages
Expand Down Expand Up @@ -193,7 +208,7 @@ export class StreamableHTTPClientTransport implements Transport {
);
}

this._handleSseStream(response.body);
this._handleSseStream(response.body, options);
} catch (error) {
this.onerror?.(error as Error);
throw error;
Expand Down Expand Up @@ -224,7 +239,7 @@ export class StreamableHTTPClientTransport implements Transport {
* @param lastEventId The ID of the last received event for resumability
* @param attemptCount Current reconnection attempt count for this specific stream
*/
private _scheduleReconnection(lastEventId: string, attemptCount = 0): void {
private _scheduleReconnection(options: StartSSEOptions, attemptCount = 0): void {
// Use provided options or default options
const maxRetries = this._reconnectionOptions.maxRetries;

Expand All @@ -240,18 +255,19 @@ export class StreamableHTTPClientTransport implements Transport {
// Schedule the reconnection
setTimeout(() => {
// Use the last event ID to resume where we left off
this._startOrAuthStandaloneSSE({ lastEventId }).catch(error => {
this._startOrAuthSse(options).catch(error => {
this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`));
// Schedule another attempt if this one failed, incrementing the attempt counter
this._scheduleReconnection(lastEventId, attemptCount + 1);
this._scheduleReconnection(options, attemptCount + 1);
});
}, delay);
}

private _handleSseStream(stream: ReadableStream<Uint8Array> | null): void {
private _handleSseStream(stream: ReadableStream<Uint8Array> | null, options: StartSSEOptions): void {
if (!stream) {
return;
}
const { onLastEventIdUpdate, replayMessageId } = options;

let lastEventId: string | undefined;
const processStream = async () => {
Expand All @@ -274,11 +290,15 @@ export class StreamableHTTPClientTransport implements Transport {
// Update last event ID if provided
if (event.id) {
lastEventId = event.id;
onLastEventIdUpdate?.(lastEventId);
}

if (!event.event || event.event === "message") {
try {
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
if (replayMessageId !== undefined && isJSONRPCResponse(message)) {
message.id = replayMessageId;
}
this.onmessage?.(message);
} catch (error) {
this.onerror?.(error as Error);
Expand All @@ -294,7 +314,7 @@ export class StreamableHTTPClientTransport implements Transport {
// Use the exponential backoff reconnection strategy
if (lastEventId !== undefined) {
try {
this._scheduleReconnection(lastEventId, 0);
this._scheduleReconnection(options, 0);
}
catch (error) {
this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`));
Expand Down Expand Up @@ -338,8 +358,18 @@ export class StreamableHTTPClientTransport implements Transport {
this.onclose?.();
}

async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise<void> {
async send(message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string, onresumptiontoken?: (event: string) => void }): Promise<void> {
try {
// If client passes in a lastEventId in the request options, we need to reconnect the SSE stream
const lastEventId = options?.resumptionToken
const onLastEventIdUpdate = options?.onresumptiontoken;
if (lastEventId) {

// If we have at last event ID, we need to reconnect the SSE stream
this._startOrAuthSse({ lastEventId, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch(err => this.onerror?.(err));
return;
}

const headers = await this._commonHeaders();
headers.set("content-type", "application/json");
headers.set("accept", "application/json, text/event-stream");
Expand Down Expand Up @@ -383,7 +413,7 @@ export class StreamableHTTPClientTransport implements Transport {
// if it's supported by the server
if (isJSONRPCNotification(message) && message.method === "notifications/initialized") {
// Start without a lastEventId since this is a fresh connection
this._startOrAuthStandaloneSSE({ lastEventId: undefined }).catch(err => this.onerror?.(err));
this._startOrAuthSse({ lastEventId: undefined }).catch(err => this.onerror?.(err));
}
return;
}
Expand All @@ -398,7 +428,10 @@ export class StreamableHTTPClientTransport implements Transport {

if (hasRequests) {
if (contentType?.includes("text/event-stream")) {
this._handleSseStream(response.body);
// Handle SSE stream responses for requests
// We use the same handler as standalone streams, which now supports
// reconnection with the last event ID
this._handleSseStream(response.body, { onLastEventIdUpdate });
} else if (contentType?.includes("application/json")) {
// For non-streaming servers, we might get direct JSON responses
const data = await response.json();
Expand All @@ -421,4 +454,8 @@ export class StreamableHTTPClientTransport implements Transport {
throw error;
}
}

get sessionId(): string | undefined {
return this._sessionId;
}
}
18 changes: 15 additions & 3 deletions src/examples/client/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ let notificationCount = 0;
let client: Client | null = null;
let transport: StreamableHTTPClientTransport | null = null;
let serverUrl = 'http://localhost:3000/mcp';
let notificationsToolLastEventId: string | undefined = undefined;
let sessionId: string | undefined = undefined;

async function main(): Promise<void> {
console.log('MCP Interactive Client');
Expand Down Expand Up @@ -109,7 +111,7 @@ function commandLoop(): void {

case 'start-notifications': {
const interval = args[1] ? parseInt(args[1], 10) : 2000;
const count = args[2] ? parseInt(args[2], 10) : 0;
const count = args[2] ? parseInt(args[2], 10) : 10;
await startNotifications(interval, count);
break;
}
Expand Down Expand Up @@ -186,7 +188,10 @@ async function connect(url?: string): Promise<void> {
}

transport = new StreamableHTTPClientTransport(
new URL(serverUrl)
new URL(serverUrl),
{
sessionId: sessionId
}
);

// Set up notification handlers
Expand Down Expand Up @@ -218,6 +223,8 @@ async function connect(url?: string): Promise<void> {

// Connect the client
await client.connect(transport);
sessionId = transport.sessionId
console.log('Transport created with session ID:', sessionId);
console.log('Connected to MCP server');
} catch (error) {
console.error('Failed to connect:', error);
Expand Down Expand Up @@ -291,7 +298,12 @@ async function callTool(name: string, args: Record<string, unknown>): Promise<vo
};

console.log(`Calling tool '${name}' with args:`, args);
const result = await client.request(request, CallToolResultSchema);
const onLastEventIdUpdate = (event: string) => {
notificationsToolLastEventId = event;
};
const result = await client.request(request, CallToolResultSchema, {
resumptionToken: notificationsToolLastEventId, onresumptiontoken: onLastEventIdUpdate
});

console.log('Tool result:');
result.content.forEach(item => {
Expand Down
20 changes: 12 additions & 8 deletions src/examples/server/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,18 @@ server.tool(

while (count === 0 || counter < count) {
counter++;
await sendNotification({
method: "notifications/message",
params: {
level: "info",
data: `Periodic notification #${counter} at ${new Date().toISOString()}`
}
});

try {
await sendNotification({
method: "notifications/message",
params: {
level: "info",
data: `Periodic notification #${counter} at ${new Date().toISOString()}`
}
});
}
catch (error) {
console.error("Error sending notification:", error);
}
// Wait for the specified interval
await sleep(interval);
}
Expand Down
Loading