diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 97c9e14aa..02026c4f3 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -43,6 +43,64 @@ describe("OAuth Authorization", () => { }); }); + it("returns metadata when first fetch fails but second without MCP header succeeds", async () => { + // Set up a counter to control behavior + let callCount = 0; + + // Mock implementation that changes behavior based on call count + mockFetch.mockImplementation((_url, _options) => { + callCount++; + + if (callCount === 1) { + // First call with MCP header - fail with TypeError (simulating CORS error) + // We need to use TypeError specifically because that's what the implementation checks for + return Promise.reject(new TypeError("Network error")); + } else { + // Second call without header - succeed + return Promise.resolve({ + ok: true, + status: 200, + json: async () => validMetadata + }); + } + }); + + // Should succeed with the second call + const metadata = await discoverOAuthMetadata("https://auth.example.com"); + expect(metadata).toEqual(validMetadata); + + // Verify both calls were made + expect(mockFetch).toHaveBeenCalledTimes(2); + + // Verify first call had MCP header + expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version"); + }); + + it("throws an error when all fetch attempts fail", async () => { + // Set up a counter to control behavior + let callCount = 0; + + // Mock implementation that changes behavior based on call count + mockFetch.mockImplementation((_url, _options) => { + callCount++; + + if (callCount === 1) { + // First call - fail with TypeError + return Promise.reject(new TypeError("First failure")); + } else { + // Second call - fail with different error + return Promise.reject(new Error("Second failure")); + } + }); + + // Should fail with the second error + await expect(discoverOAuthMetadata("https://auth.example.com")) + .rejects.toThrow("Second failure"); + + // Verify both calls were made + expect(mockFetch).toHaveBeenCalledTimes(2); + }); + it("returns undefined when discovery endpoint returns 404", async () => { mockFetch.mockResolvedValueOnce({ ok: false, diff --git a/src/client/auth.ts b/src/client/auth.ts index c0bac9e8e..c7799429e 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -163,11 +163,21 @@ export async function discoverOAuthMetadata( opts?: { protocolVersion?: string }, ): Promise { const url = new URL("/.well-known/oauth-authorization-server", serverUrl); - const response = await fetch(url, { - headers: { - "MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION + let response: Response; + try { + response = await fetch(url, { + headers: { + "MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION + } + }); + } catch (error) { + // CORS errors come back as TypeError + if (error instanceof TypeError) { + response = await fetch(url); + } else { + throw error; } - }); + } if (response.status === 404) { return undefined; diff --git a/src/server/auth/handlers/register.test.ts b/src/server/auth/handlers/register.test.ts index 5faf1a4a4..a961f6543 100644 --- a/src/server/auth/handlers/register.test.ts +++ b/src/server/auth/handlers/register.test.ts @@ -141,6 +141,37 @@ describe('Client Registration Handler', () => { expect(response.status).toBe(201); expect(response.body.client_secret).toBeUndefined(); + expect(response.body.client_secret_expires_at).toBeUndefined(); + }); + + it('sets client_secret_expires_at for public clients only', async () => { + // Test for public client (token_endpoint_auth_method not 'none') + const publicClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'client_secret_basic' + }; + + const publicResponse = await supertest(app) + .post('/register') + .send(publicClientMetadata); + + expect(publicResponse.status).toBe(201); + expect(publicResponse.body.client_secret).toBeDefined(); + expect(publicResponse.body.client_secret_expires_at).toBeDefined(); + + // Test for non-public client (token_endpoint_auth_method is 'none') + const nonPublicClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'none' + }; + + const nonPublicResponse = await supertest(app) + .post('/register') + .send(nonPublicClientMetadata); + + expect(nonPublicResponse.status).toBe(201); + expect(nonPublicResponse.body.client_secret).toBeUndefined(); + expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined(); }); it('sets expiry based on clientSecretExpirySeconds', async () => { diff --git a/src/server/auth/handlers/register.ts b/src/server/auth/handlers/register.ts index 675e8733c..30b7cdf8f 100644 --- a/src/server/auth/handlers/register.ts +++ b/src/server/auth/handlers/register.ts @@ -75,20 +75,26 @@ export function clientRegistrationHandler({ } const clientMetadata = parseResult.data; + const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none' // Generate client credentials const clientId = crypto.randomUUID(); - const clientSecret = clientMetadata.token_endpoint_auth_method !== 'none' - ? crypto.randomBytes(32).toString('hex') - : undefined; + const clientSecret = isPublicClient + ? undefined + : crypto.randomBytes(32).toString('hex'); const clientIdIssuedAt = Math.floor(Date.now() / 1000); + // Calculate client secret expiry time + const clientsDoExpire = clientSecretExpirySeconds > 0 + const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0 + const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime + let clientInfo: OAuthClientInformationFull = { ...clientMetadata, client_id: clientId, client_secret: clientSecret, client_id_issued_at: clientIdIssuedAt, - client_secret_expires_at: clientSecretExpirySeconds > 0 ? clientIdIssuedAt + clientSecretExpirySeconds : 0 + client_secret_expires_at: clientSecretExpiresAt, }; clientInfo = await clientsStore.registerClient!(clientInfo); diff --git a/src/server/auth/middleware/bearerAuth.test.ts b/src/server/auth/middleware/bearerAuth.test.ts index 8c0b595e1..da2f58381 100644 --- a/src/server/auth/middleware/bearerAuth.test.ts +++ b/src/server/auth/middleware/bearerAuth.test.ts @@ -55,6 +55,57 @@ describe("requireBearerAuth middleware", () => { expect(mockResponse.status).not.toHaveBeenCalled(); expect(mockResponse.json).not.toHaveBeenCalled(); }); + + it("should reject expired tokens", async () => { + const expiredAuthInfo: AuthInfo = { + token: "expired-token", + clientId: "client-123", + scopes: ["read", "write"], + expiresAt: Math.floor(Date.now() / 1000) - 100, // Token expired 100 seconds ago + }; + mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); + + mockRequest.headers = { + authorization: "Bearer expired-token", + }; + + const middleware = requireBearerAuth({ provider: mockProvider }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token"); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + "WWW-Authenticate", + expect.stringContaining('Bearer error="invalid_token"') + ); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: "invalid_token", error_description: "Token has expired" }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it("should accept non-expired tokens", async () => { + const nonExpiredAuthInfo: AuthInfo = { + token: "valid-token", + clientId: "client-123", + scopes: ["read", "write"], + expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo); + + mockRequest.headers = { + authorization: "Bearer valid-token", + }; + + const middleware = requireBearerAuth({ provider: mockProvider }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token"); + expect(mockRequest.auth).toEqual(nonExpiredAuthInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); + }); it("should require specific scopes when configured", async () => { const authInfo: AuthInfo = { diff --git a/src/server/auth/middleware/bearerAuth.ts b/src/server/auth/middleware/bearerAuth.ts index 14109e174..cd1b314af 100644 --- a/src/server/auth/middleware/bearerAuth.ts +++ b/src/server/auth/middleware/bearerAuth.ts @@ -55,6 +55,11 @@ export function requireBearerAuth({ provider, requiredScopes = [] }: BearerAuthM } } + // Check if the token is expired + if (!!authInfo.expiresAt && authInfo.expiresAt < Date.now() / 1000) { + throw new InvalidTokenError("Token has expired"); + } + req.auth = authInfo; next(); } catch (error) {