diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index f2dadbb15..6c924898a 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -10,7 +10,8 @@ import { discoverOAuthProtectedResourceMetadata, extractResourceMetadataUrl, auth, - type OAuthClientProvider + type OAuthClientProvider, + selectClientAuthMethod } from './auth.js'; import { ServerError } from '../server/auth/errors.js'; import { AuthorizationServerMetadata } from '../shared/auth.js'; @@ -881,6 +882,25 @@ describe('OAuth Authorization', () => { }); }); + describe('selectClientAuthMethod', () => { + it('selects the correct client authentication method from client information', () => { + const clientInfo = { + client_id: 'test-client-id', + client_secret: 'test-client-secret', + token_endpoint_auth_method: 'client_secret_basic' + }; + const supportedMethods = ['client_secret_post', 'client_secret_basic', 'none']; + const authMethod = selectClientAuthMethod(clientInfo, supportedMethods); + expect(authMethod).toBe('client_secret_basic'); + }); + it('selects the correct client authentication method from supported methods', () => { + const clientInfo = { client_id: 'test-client-id' }; + const supportedMethods = ['client_secret_post', 'client_secret_basic', 'none']; + const authMethod = selectClientAuthMethod(clientInfo, supportedMethods); + expect(authMethod).toBe('none'); + }); + }); + describe('startAuthorization', () => { const validMetadata = { issuer: 'https://auth.example.com', diff --git a/src/client/auth.ts b/src/client/auth.ts index 1e90f34ba..5e48345a3 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -3,6 +3,7 @@ import { LATEST_PROTOCOL_VERSION } from '../types.js'; import { OAuthClientMetadata, OAuthClientInformation, + OAuthClientInformationMixed, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, @@ -56,7 +57,7 @@ export interface OAuthClientProvider { * server, or returns `undefined` if the client is not registered with the * server. */ - clientInformation(): OAuthClientInformation | undefined | Promise; + clientInformation(): OAuthClientInformationMixed | undefined | Promise; /** * If implemented, this permits the OAuth client to dynamically register with @@ -66,7 +67,7 @@ export interface OAuthClientProvider { * This method is not required to be implemented if client information is * statically known (e.g., pre-registered). */ - saveClientInformation?(clientInformation: OAuthClientInformationFull): void | Promise; + saveClientInformation?(clientInformation: OAuthClientInformationMixed): void | Promise; /** * Loads any existing OAuth tokens for the current session, or returns @@ -149,6 +150,10 @@ export class UnauthorizedError extends Error { type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; +function isClientAuthMethod(method: string): method is ClientAuthMethod { + return ['client_secret_basic', 'client_secret_post', 'none'].includes(method); +} + const AUTHORIZATION_CODE_RESPONSE_TYPE = 'code'; const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256'; @@ -164,7 +169,7 @@ const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256'; * @param supportedMethods - Authentication methods supported by the authorization server * @returns The selected authentication method */ -function selectClientAuthMethod(clientInformation: OAuthClientInformation, supportedMethods: string[]): ClientAuthMethod { +export function selectClientAuthMethod(clientInformation: OAuthClientInformationMixed, supportedMethods: string[]): ClientAuthMethod { const hasClientSecret = clientInformation.client_secret !== undefined; // If server doesn't specify supported methods, use RFC 6749 defaults @@ -172,6 +177,16 @@ function selectClientAuthMethod(clientInformation: OAuthClientInformation, suppo return hasClientSecret ? 'client_secret_post' : 'none'; } + // Prefer the method returned by the server during client registration if valid and supported + if ( + 'token_endpoint_auth_method' in clientInformation && + clientInformation.token_endpoint_auth_method && + isClientAuthMethod(clientInformation.token_endpoint_auth_method) && + supportedMethods.includes(clientInformation.token_endpoint_auth_method) + ) { + return clientInformation.token_endpoint_auth_method; + } + // Try methods in priority order (most secure first) if (hasClientSecret && supportedMethods.includes('client_secret_basic')) { return 'client_secret_basic'; @@ -793,7 +808,7 @@ export async function startAuthorization( resource }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationMixed; redirectUrl: string | URL; scope?: string; state?: string; @@ -876,7 +891,7 @@ export async function exchangeAuthorization( fetchFn }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationMixed; authorizationCode: string; codeVerifier: string; redirectUri: string | URL; @@ -955,7 +970,7 @@ export async function refreshAuthorization( fetchFn }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationMixed; refreshToken: string; resource?: URL; addClientAuthentication?: OAuthClientProvider['addClientAuthentication']; diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index 354886050..fc296bc6a 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -6,7 +6,7 @@ import { URL } from 'node:url'; import { exec } from 'node:child_process'; import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; -import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; +import { OAuthClientInformationMixed, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; import { CallToolRequest, ListToolsRequest, CallToolResultSchema, ListToolsResultSchema } from '../../types.js'; import { OAuthClientProvider, UnauthorizedError } from '../../client/auth.js'; @@ -20,7 +20,7 @@ const CALLBACK_URL = `http://localhost:${CALLBACK_PORT}/callback`; * In production, you should persist tokens securely */ class InMemoryOAuthClientProvider implements OAuthClientProvider { - private _clientInformation?: OAuthClientInformationFull; + private _clientInformation?: OAuthClientInformationMixed; private _tokens?: OAuthTokens; private _codeVerifier?: string; @@ -46,11 +46,11 @@ class InMemoryOAuthClientProvider implements OAuthClientProvider { return this._clientMetadata; } - clientInformation(): OAuthClientInformation | undefined { + clientInformation(): OAuthClientInformationMixed | undefined { return this._clientInformation; } - saveClientInformation(clientInformation: OAuthClientInformationFull): void { + saveClientInformation(clientInformation: OAuthClientInformationMixed): void { this._clientInformation = clientInformation; } diff --git a/src/shared/auth.ts b/src/shared/auth.ts index c5ddbda16..819b33086 100644 --- a/src/shared/auth.ts +++ b/src/shared/auth.ts @@ -226,6 +226,7 @@ export type OAuthErrorResponse = z.infer; export type OAuthClientMetadata = z.infer; export type OAuthClientInformation = z.infer; export type OAuthClientInformationFull = z.infer; +export type OAuthClientInformationMixed = OAuthClientInformation | OAuthClientInformationFull; export type OAuthClientRegistrationError = z.infer; export type OAuthTokenRevocationRequest = z.infer; export type OAuthProtectedResourceMetadata = z.infer;