diff --git a/README.md b/README.md index a959ae3dd..080532255 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ - [Low-Level Server](#low-level-server) - [Writing MCP Clients](#writing-mcp-clients) - [Server Capabilities](#server-capabilities) + - [Proxy OAuth Server](#proxy-authorization-requests-upstream) ## Overview @@ -489,6 +490,52 @@ const result = await client.callTool({ }); ``` +### Proxy Authorization Requests Upstream + +You can proxy OAuth requests to an external authorization provider: + +```typescript +import express from 'express'; +import { ProxyOAuthServerProvider, mcpAuthRouter } from '@modelcontextprotocol/sdk'; + +const app = express(); + +const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: "https://auth.external.com/oauth2/v1/authorize", + tokenUrl: "https://auth.external.com/oauth2/v1/token", + revocationUrl: "https://auth.external.com/oauth2/v1/revoke", + }, + verifyAccessToken: async (token) => { + return { + token, + clientId: "123", + scopes: ["openid", "email", "profile"], + } + }, + getClient: async (client_id) => { + return { + client_id, + redirect_uris: ["http://localhost:3000/callback"], + } + } +}) + +app.use(mcpAuthRouter({ + provider: proxyProvider, + issuerUrl: new URL("http://auth.external.com"), + baseUrl: new URL("http://mcp.example.com"), + serviceDocumentationUrl: new URL("https://docs.example.com/"), +})) +``` + +This setup allows you to: +- Forward OAuth requests to an external provider +- Add custom token validation logic +- Manage client registrations +- Provide custom documentation URLs +- Maintain control over the OAuth flow while delegating to an external provider + ## Documentation - [Model Context Protocol documentation](https://modelcontextprotocol.io) diff --git a/src/server/auth/handlers/token.test.ts b/src/server/auth/handlers/token.test.ts index 7d15e44a2..bf41b5ebd 100644 --- a/src/server/auth/handlers/token.test.ts +++ b/src/server/auth/handlers/token.test.ts @@ -7,6 +7,7 @@ import supertest from 'supertest'; import * as pkceChallenge from 'pkce-challenge'; import { InvalidGrantError, InvalidTokenError } from '../errors.js'; import { AuthInfo } from '../types.js'; +import { ProxyOAuthServerProvider } from '../providers/proxyProvider.js'; // Mock pkce-challenge jest.mock('pkce-challenge', () => ({ @@ -280,6 +281,67 @@ describe('Token Handler', () => { expect(response.body.expires_in).toBe(3600); expect(response.body.refresh_token).toBe('mock_refresh_token'); }); + + it('passes through code verifier when using proxy provider', async () => { + const originalFetch = global.fetch; + + try { + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }) + }); + + const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://example.com/authorize', + tokenUrl: 'https://example.com/token' + }, + verifyAccessToken: async (token) => ({ + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }), + getClient: async (clientId) => clientId === 'valid-client' ? validClient : undefined + }); + + const proxyApp = express(); + const options: TokenHandlerOptions = { provider: proxyProvider }; + proxyApp.use('/token', tokenHandler(options)); + + const response = await supertest(proxyApp) + .post('/token') + .type('form') + .send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'any_verifier' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('code_verifier=any_verifier') + }) + ); + } finally { + global.fetch = originalFetch; + } + }); }); describe('Refresh token grant', () => { diff --git a/src/server/auth/handlers/token.ts b/src/server/auth/handlers/token.ts index 79312068a..28412a014 100644 --- a/src/server/auth/handlers/token.ts +++ b/src/server/auth/handlers/token.ts @@ -90,13 +90,19 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand const { code, code_verifier } = parseResult.data; - // Verify PKCE challenge - const codeChallenge = await provider.challengeForAuthorizationCode(client, code); - if (!(await verifyChallenge(code_verifier, codeChallenge))) { - throw new InvalidGrantError("code_verifier does not match the challenge"); + const skipLocalPkceValidation = provider.skipLocalPkceValidation; + + // Perform local PKCE validation unless explicitly skipped + // (e.g. to validate code_verifier in upstream server) + if (!skipLocalPkceValidation) { + const codeChallenge = await provider.challengeForAuthorizationCode(client, code); + if (!(await verifyChallenge(code_verifier, codeChallenge))) { + throw new InvalidGrantError("code_verifier does not match the challenge"); + } } - const tokens = await provider.exchangeAuthorizationCode(client, code); + // Passes the code_verifier to the provider if PKCE validation didn't occur locally + const tokens = await provider.exchangeAuthorizationCode(client, code, skipLocalPkceValidation ? code_verifier : undefined); res.status(200).json(tokens); break; } diff --git a/src/server/auth/provider.ts b/src/server/auth/provider.ts index 7416c5544..dc186bcaf 100644 --- a/src/server/auth/provider.ts +++ b/src/server/auth/provider.ts @@ -36,7 +36,7 @@ export interface OAuthServerProvider { /** * Exchanges an authorization code for an access token. */ - exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise; + exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string, codeVerifier?: string): Promise; /** * Exchanges a refresh token for an access token. @@ -54,4 +54,13 @@ export interface OAuthServerProvider { * If the given token is invalid or already revoked, this method should do nothing. */ revokeToken?(client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest): Promise; + + /** + * Whether to skip local PKCE validation. + * + * If true, the server will not perform PKCE validation locally and will pass the code_verifier to the upstream server. + * + * NOTE: This should only be true if the upstream server is performing the actual PKCE validation. + */ + skipLocalPkceValidation?: boolean; } \ No newline at end of file diff --git a/src/server/auth/providers/proxyProvider.test.ts b/src/server/auth/providers/proxyProvider.test.ts new file mode 100644 index 000000000..6e842ea33 --- /dev/null +++ b/src/server/auth/providers/proxyProvider.test.ts @@ -0,0 +1,286 @@ +import { Response } from "express"; +import { ProxyOAuthServerProvider, ProxyOptions } from "./proxyProvider.js"; +import { AuthInfo } from "../types.js"; +import { OAuthClientInformationFull, OAuthTokens } from "../../../shared/auth.js"; +import { ServerError } from "../errors.js"; +import { InvalidTokenError } from "../errors.js"; +import { InsufficientScopeError } from "../errors.js"; + +describe("Proxy OAuth Server Provider", () => { + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: "test-client", + client_secret: "test-secret", + redirect_uris: ["https://example.com/callback"], + }; + + // Mock response object + const mockResponse = { + redirect: jest.fn(), + } as unknown as Response; + + // Mock provider functions + const mockVerifyToken = jest.fn(); + const mockGetClient = jest.fn(); + + // Base provider options + const baseOptions: ProxyOptions = { + endpoints: { + authorizationUrl: "https://auth.example.com/authorize", + tokenUrl: "https://auth.example.com/token", + revocationUrl: "https://auth.example.com/revoke", + registrationUrl: "https://auth.example.com/register", + }, + verifyAccessToken: mockVerifyToken, + getClient: mockGetClient, + }; + + let provider: ProxyOAuthServerProvider; + let originalFetch: typeof global.fetch; + + beforeEach(() => { + provider = new ProxyOAuthServerProvider(baseOptions); + originalFetch = global.fetch; + global.fetch = jest.fn(); + + // Setup mock implementations + mockVerifyToken.mockImplementation(async (token: string) => { + if (token === "valid-token") { + return { + token, + clientId: "test-client", + scopes: ["read", "write"], + expiresAt: Date.now() / 1000 + 3600, + } as AuthInfo; + } + throw new InvalidTokenError("Invalid token"); + }); + + mockGetClient.mockImplementation(async (clientId: string) => { + if (clientId === "test-client") { + return validClient; + } + return undefined; + }); + }); + + // Add helper function for failed responses + const mockFailedResponse = () => { + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: false, + status: 400, + }) + ); + }; + + afterEach(() => { + global.fetch = originalFetch; + jest.clearAllMocks(); + }); + + describe("authorization", () => { + it("redirects to authorization endpoint with correct parameters", async () => { + await provider.authorize( + validClient, + { + redirectUri: "https://example.com/callback", + codeChallenge: "test-challenge", + state: "test-state", + scopes: ["read", "write"], + }, + mockResponse + ); + + const expectedUrl = new URL("https://auth.example.com/authorize"); + expectedUrl.searchParams.set("client_id", "test-client"); + expectedUrl.searchParams.set("response_type", "code"); + expectedUrl.searchParams.set("redirect_uri", "https://example.com/callback"); + expectedUrl.searchParams.set("code_challenge", "test-challenge"); + expectedUrl.searchParams.set("code_challenge_method", "S256"); + expectedUrl.searchParams.set("state", "test-state"); + expectedUrl.searchParams.set("scope", "read write"); + + expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); + }); + }); + + describe("token exchange", () => { + const mockTokenResponse: OAuthTokens = { + access_token: "new-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "new-refresh-token", + }; + + beforeEach(() => { + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }) + ); + }); + + it("exchanges authorization code for tokens", async () => { + const tokens = await provider.exchangeAuthorizationCode( + validClient, + "test-code", + "test-verifier" + ); + + expect(global.fetch).toHaveBeenCalledWith( + "https://auth.example.com/token", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: expect.stringContaining("grant_type=authorization_code") + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it("exchanges refresh token for new tokens", async () => { + const tokens = await provider.exchangeRefreshToken( + validClient, + "test-refresh-token", + ["read", "write"] + ); + + expect(global.fetch).toHaveBeenCalledWith( + "https://auth.example.com/token", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: expect.stringContaining("grant_type=refresh_token") + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + }); + + describe("client registration", () => { + it("registers new client", async () => { + const newClient: OAuthClientInformationFull = { + client_id: "new-client", + redirect_uris: ["https://new-client.com/callback"], + }; + + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(newClient), + }) + ); + + const result = await provider.clientsStore.registerClient!(newClient); + + expect(global.fetch).toHaveBeenCalledWith( + "https://auth.example.com/register", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(newClient), + }) + ); + expect(result).toEqual(newClient); + }); + + it("handles registration failure", async () => { + mockFailedResponse(); + const newClient: OAuthClientInformationFull = { + client_id: "new-client", + redirect_uris: ["https://new-client.com/callback"], + }; + + await expect( + provider.clientsStore.registerClient!(newClient) + ).rejects.toThrow(ServerError); + }); + }); + + describe("token revocation", () => { + it("revokes token", async () => { + (global.fetch as jest.Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + }) + ); + + await provider.revokeToken!(validClient, { + token: "token-to-revoke", + token_type_hint: "access_token", + }); + + expect(global.fetch).toHaveBeenCalledWith( + "https://auth.example.com/revoke", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: expect.stringContaining("token=token-to-revoke"), + }) + ); + }); + + it("handles revocation failure", async () => { + mockFailedResponse(); + await expect( + provider.revokeToken!(validClient, { + token: "invalid-token", + }) + ).rejects.toThrow(ServerError); + }); + }); + + describe("token verification", () => { + it("verifies valid token", async () => { + const validAuthInfo: AuthInfo = { + token: "valid-token", + clientId: "test-client", + scopes: ["read", "write"], + expiresAt: Date.now() / 1000 + 3600, + }; + mockVerifyToken.mockResolvedValue(validAuthInfo); + + const authInfo = await provider.verifyAccessToken("valid-token"); + expect(authInfo).toEqual(validAuthInfo); + expect(mockVerifyToken).toHaveBeenCalledWith("valid-token"); + }); + + it("passes through InvalidTokenError", async () => { + const error = new InvalidTokenError("Token expired"); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken("invalid-token")) + .rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith("invalid-token"); + }); + + it("passes through InsufficientScopeError", async () => { + const error = new InsufficientScopeError("Required scopes: read, write"); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken("token-with-insufficient-scope")) + .rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith("token-with-insufficient-scope"); + }); + + it("passes through unexpected errors", async () => { + const error = new Error("Unexpected error"); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken("valid-token")) + .rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith("valid-token"); + }); + }); +}); \ No newline at end of file diff --git a/src/server/auth/providers/proxyProvider.ts b/src/server/auth/providers/proxyProvider.ts new file mode 100644 index 000000000..be4503050 --- /dev/null +++ b/src/server/auth/providers/proxyProvider.ts @@ -0,0 +1,226 @@ +import { Response } from "express"; +import { OAuthRegisteredClientsStore } from "../clients.js"; +import { + OAuthClientInformationFull, + OAuthClientInformationFullSchema, + OAuthTokenRevocationRequest, + OAuthTokens, + OAuthTokensSchema, +} from "../../../shared/auth.js"; +import { AuthInfo } from "../types.js"; +import { AuthorizationParams, OAuthServerProvider } from "../provider.js"; +import { ServerError } from "../errors.js"; + +export type ProxyEndpoints = { + authorizationUrl: string; + tokenUrl: string; + revocationUrl?: string; + registrationUrl?: string; +}; + +export type ProxyOptions = { + /** + * Individual endpoint URLs for proxying specific OAuth operations + */ + endpoints: ProxyEndpoints; + + /** + * Function to verify access tokens and return auth info + */ + verifyAccessToken: (token: string) => Promise; + + /** + * Function to fetch client information from the upstream server + */ + getClient: (clientId: string) => Promise; + +}; + +/** + * Implements an OAuth server that proxies requests to another OAuth server. + */ +export class ProxyOAuthServerProvider implements OAuthServerProvider { + protected readonly _endpoints: ProxyEndpoints; + protected readonly _verifyAccessToken: (token: string) => Promise; + protected readonly _getClient: (clientId: string) => Promise; + + skipLocalPkceValidation = true; + + revokeToken?: ( + client: OAuthClientInformationFull, + request: OAuthTokenRevocationRequest + ) => Promise; + + constructor(options: ProxyOptions) { + this._endpoints = options.endpoints; + this._verifyAccessToken = options.verifyAccessToken; + this._getClient = options.getClient; + if (options.endpoints?.revocationUrl) { + this.revokeToken = async ( + client: OAuthClientInformationFull, + request: OAuthTokenRevocationRequest + ) => { + const revocationUrl = this._endpoints.revocationUrl; + + if (!revocationUrl) { + throw new Error("No revocation endpoint configured"); + } + + const params = new URLSearchParams(); + params.set("token", request.token); + params.set("client_id", client.client_id); + if (client.client_secret) { + params.set("client_secret", client.client_secret); + } + if (request.token_type_hint) { + params.set("token_type_hint", request.token_type_hint); + } + + const response = await fetch(revocationUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params.toString(), + }); + + if (!response.ok) { + throw new ServerError(`Token revocation failed: ${response.status}`); + } + } + } + } + + get clientsStore(): OAuthRegisteredClientsStore { + const registrationUrl = this._endpoints.registrationUrl; + return { + getClient: this._getClient, + ...(registrationUrl && { + registerClient: async (client: OAuthClientInformationFull) => { + const response = await fetch(registrationUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(client), + }); + + if (!response.ok) { + throw new ServerError(`Client registration failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthClientInformationFullSchema.parse(data); + } + }) + } + } + + async authorize( + client: OAuthClientInformationFull, + params: AuthorizationParams, + res: Response + ): Promise { + // Start with required OAuth parameters + const targetUrl = new URL(this._endpoints.authorizationUrl); + const searchParams = new URLSearchParams({ + client_id: client.client_id, + response_type: "code", + redirect_uri: params.redirectUri, + code_challenge: params.codeChallenge, + code_challenge_method: "S256" + }); + + // Add optional standard OAuth parameters + if (params.state) searchParams.set("state", params.state); + if (params.scopes?.length) searchParams.set("scope", params.scopes.join(" ")); + + targetUrl.search = searchParams.toString(); + res.redirect(targetUrl.toString()); + } + + async challengeForAuthorizationCode( + _client: OAuthClientInformationFull, + _authorizationCode: string + ): Promise { + // In a proxy setup, we don't store the code challenge ourselves + // Instead, we proxy the token request and let the upstream server validate it + return ""; + } + + async exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + codeVerifier?: string + ): Promise { + const params = new URLSearchParams({ + grant_type: "authorization_code", + client_id: client.client_id, + code: authorizationCode, + }); + + if (client.client_secret) { + params.append("client_secret", client.client_secret); + } + + if (codeVerifier) { + params.append("code_verifier", codeVerifier); + } + + const response = await fetch(this._endpoints.tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params.toString(), + }); + + + if (!response.ok) { + throw new ServerError(`Token exchange failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthTokensSchema.parse(data); + } + + async exchangeRefreshToken( + client: OAuthClientInformationFull, + refreshToken: string, + scopes?: string[] + ): Promise { + + const params = new URLSearchParams({ + grant_type: "refresh_token", + client_id: client.client_id, + refresh_token: refreshToken, + }); + + if (client.client_secret) { + params.set("client_secret", client.client_secret); + } + + if (scopes?.length) { + params.set("scope", scopes.join(" ")); + } + + const response = await fetch(this._endpoints.tokenUrl, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: params.toString(), + }); + + if (!response.ok) { + throw new ServerError(`Token refresh failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthTokensSchema.parse(data); + } + + async verifyAccessToken(token: string): Promise { + return this._verifyAccessToken(token); + } +} \ No newline at end of file diff --git a/src/server/auth/router.ts b/src/server/auth/router.ts index 30e22c417..49d451c29 100644 --- a/src/server/auth/router.ts +++ b/src/server/auth/router.ts @@ -17,6 +17,13 @@ export type AuthRouterOptions = { */ issuerUrl: URL; + /** + * The base URL of the authorization server to use for the metadata endpoints. + * + * If not provided, the issuer URL will be used as the base URL. + */ + baseUrl?: URL; + /** * An optional URL of a page containing human-readable information that developers might want or need to know when using the authorization server. */ @@ -41,6 +48,7 @@ export type AuthRouterOptions = { */ export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { const issuer = options.issuerUrl; + const baseUrl = options.baseUrl; // Technically RFC 8414 does not permit a localhost HTTPS exemption, but this will be necessary for ease of testing if (issuer.protocol !== "https:" && issuer.hostname !== "localhost" && issuer.hostname !== "127.0.0.1") { @@ -62,18 +70,18 @@ export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { issuer: issuer.href, service_documentation: options.serviceDocumentationUrl?.href, - authorization_endpoint: new URL(authorization_endpoint, issuer).href, + authorization_endpoint: new URL(authorization_endpoint, baseUrl || issuer).href, response_types_supported: ["code"], code_challenge_methods_supported: ["S256"], - token_endpoint: new URL(token_endpoint, issuer).href, + token_endpoint: new URL(token_endpoint, baseUrl || issuer).href, token_endpoint_auth_methods_supported: ["client_secret_post"], grant_types_supported: ["authorization_code", "refresh_token"], - revocation_endpoint: revocation_endpoint ? new URL(revocation_endpoint, issuer).href : undefined, + revocation_endpoint: revocation_endpoint ? new URL(revocation_endpoint, baseUrl || issuer).href : undefined, revocation_endpoint_auth_methods_supported: revocation_endpoint ? ["client_secret_post"] : undefined, - registration_endpoint: registration_endpoint ? new URL(registration_endpoint, issuer).href : undefined, + registration_endpoint: registration_endpoint ? new URL(registration_endpoint, baseUrl || issuer).href : undefined, }; const router = express.Router();