diff --git a/.gitignore b/.gitignore index 429a0375ae..2478cac4b3 100644 --- a/.gitignore +++ b/.gitignore @@ -89,7 +89,7 @@ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: -# .python-version +.python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 376036e8cf..91f8576d71 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -204,22 +204,19 @@ def __init__( ) self._initialized = False - def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: + def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None: """ - Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + Extract field from WWW-Authenticate header. Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise + Field value if found in WWW-Authenticate header, None otherwise """ - if not init_response or init_response.status_code != 401: - return None - www_auth_header = init_response.headers.get("WWW-Authenticate") if not www_auth_header: return None - # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) - pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + # Pattern matches: field_name="value" or field_name=value (unquoted) + pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' match = re.search(pattern, www_auth_header) if match: @@ -228,6 +225,27 @@ def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response return None + def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not init_response or init_response.status_code != 401: + return None + + return self._extract_field_from_www_auth(init_response, "resource_metadata") + + def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | None: + """ + Extract scope parameter from WWW-Authenticate header as per RFC6750. + + Returns: + Scope string if found in WWW-Authenticate header, None otherwise + """ + return self._extract_field_from_www_auth(init_response, "scope") + async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response url = self._extract_resource_metadata_from_www_auth(init_response) @@ -248,8 +266,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> self.context.protected_resource_metadata = metadata if metadata.authorization_servers: self.context.auth_server_url = str(metadata.authorization_servers[0]) + except ValidationError: pass + else: + raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}") + + def _select_scopes(self, init_response: httpx.Response) -> None: + """Select scopes as outlined in the 'Scope Selection Strategy in the MCP spec.""" + # Per MCP spec, scope selection priority order: + # 1. Use scope from WWW-Authenticate header (if provided) + # 2. Use all scopes from PRM scopes_supported (if available) + # 3. Omit scope parameter if neither is available + # + www_authenticate_scope = self._extract_scope_from_www_auth(init_response) + if www_authenticate_scope is not None: + # Priority 1: WWW-Authenticate header scope + self.context.client_metadata.scope = www_authenticate_scope + elif ( + self.context.protected_resource_metadata is not None + and self.context.protected_resource_metadata.scopes_supported is not None + ): + # Priority 2: PRM scopes_supported + self.context.client_metadata.scope = " ".join(self.context.protected_resource_metadata.scopes_supported) + else: + # Priority 3: Omit scope parameter + self.context.client_metadata.scope = None def _get_discovery_urls(self) -> list[str]: """Generate ordered list of (url, type) tuples for discovery attempts.""" @@ -478,9 +520,6 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non content = await response.aread() metadata = OAuthMetadata.model_validate_json(content) self.context.oauth_metadata = metadata - # Apply default scope if needed - if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: - self.context.client_metadata.scope = " ".join(metadata.scopes_supported) async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" @@ -514,7 +553,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. discovery_response = yield discovery_request await self._handle_protected_resource_response(discovery_response) - # Step 2: Discover OAuth metadata (with fallback for legacy servers) + # Step 2: Apply scope selection strategy + self._select_scopes(response) + + # Step 3: Discover OAuth metadata (with fallback for legacy servers) discovery_urls = self._get_discovery_urls() for url in discovery_urls: oauth_metadata_request = self._create_oauth_metadata_request(url) @@ -529,16 +571,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: break # Non-4XX error, stop trying - # Step 3: Register client if needed + # Step 4: Register client if needed registration_request = await self._register_client() if registration_request: registration_response = yield registration_request await self._handle_registration_response(registration_response) - # Step 4: Perform authorization + # Step 5: Perform authorization auth_code, code_verifier = await self._perform_authorization() - # Step 5: Exchange authorization code for tokens + # Step 6: Exchange authorization code for tokens token_request = await self._exchange_token(auth_code, code_verifier) token_response = yield token_request await self._handle_token_response(token_response) @@ -549,3 +591,27 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Retry with new tokens self._add_auth_header(request) yield request + elif response.status_code == 403: + # Step 1: Extract error field from WWW-Authenticate header + error = self._extract_field_from_www_auth(response, "error") + + # Step 2: Check if we need to step-up authorization + if error == "insufficient_scope": + try: + # Step 2a: Update the required scopes + self._select_scopes(response) + + # Step 2b: Perform (re-)authorization + auth_code, code_verifier = await self._perform_authorization() + + # Step 2c: Exchange authorization code for tokens + token_request = await self._exchange_token(auth_code, code_verifier) + token_response = yield token_request + await self._handle_token_response(token_response) + except Exception: + logger.exception("OAuth flow error") + raise + + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 6e58e496d3..fb1a93e39e 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -79,6 +79,52 @@ async def callback_handler() -> tuple[str, str | None]: ) +@pytest.fixture +def prm_metadata_response(): + """PRM metadata response with scopes.""" + return httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", ' + b'"authorization_servers": ["https://auth.example.com"], ' + b'"scopes_supported": ["resource:read", "resource:write"]}' + ), + ) + + +@pytest.fixture +def prm_metadata_without_scopes_response(): + """PRM metadata response without scopes.""" + return httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", ' + b'"authorization_servers": ["https://auth.example.com"], ' + b'"scopes_supported": null}' + ), + ) + + +@pytest.fixture +def init_response_with_www_auth_scope(): + """Initial 401 response with WWW-Authenticate header containing scope.""" + return httpx.Response( + 401, + headers={"WWW-Authenticate": 'Bearer scope="special:scope from:www-authenticate"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + +@pytest.fixture +def init_response_without_www_auth_scope(): + """Initial 401 response without WWW-Authenticate scope.""" + return httpx.Response( + 401, + headers={}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + class TestPKCEParameters: """Test PKCE parameter generation.""" @@ -391,6 +437,57 @@ async def test_handle_metadata_response_success(self, oauth_provider: OAuthClien assert oauth_provider.context.oauth_metadata is not None assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" + @pytest.mark.anyio + async def test_prioritize_www_auth_scope_over_prm( + self, + oauth_provider: OAuthClientProvider, + prm_metadata_response: httpx.Response, + init_response_with_www_auth_scope: httpx.Response, + ): + """Test that WWW-Authenticate scope is prioritized over PRM scopes.""" + # First, process PRM metadata to set protected_resource_metadata with scopes + await oauth_provider._handle_protected_resource_response(prm_metadata_response) + + # Process the scope selection with WWW-Authenticate header + oauth_provider._select_scopes(init_response_with_www_auth_scope) + + # Verify that WWW-Authenticate scope is used (not PRM scopes) + assert oauth_provider.context.client_metadata.scope == "special:scope from:www-authenticate" + + @pytest.mark.anyio + async def test_prioritize_prm_scopes_when_no_www_auth_scope( + self, + oauth_provider: OAuthClientProvider, + prm_metadata_response: httpx.Response, + init_response_without_www_auth_scope: httpx.Response, + ): + """Test that PRM scopes are prioritized when WWW-Authenticate header has no scopes.""" + # Process the PRM metadata to set protected_resource_metadata with scopes + await oauth_provider._handle_protected_resource_response(prm_metadata_response) + + # Process the scope selection without WWW-Authenticate scope + oauth_provider._select_scopes(init_response_without_www_auth_scope) + + # Verify that PRM scopes are used + assert oauth_provider.context.client_metadata.scope == "resource:read resource:write" + + @pytest.mark.anyio + async def test_omit_scope_when_no_prm_scopes_or_www_auth( + self, + oauth_provider: OAuthClientProvider, + prm_metadata_without_scopes_response: httpx.Response, + init_response_without_www_auth_scope: httpx.Response, + ): + """Test that scope is omitted when PRM has no scopes and WWW-Authenticate doesn't specify scope.""" + # Process the PRM metadata without scopes + await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) + + # Process the scope selection without WWW-Authenticate scope + oauth_provider._select_scopes(init_response_without_www_auth_scope) + + # Verify that scope is omitted + assert oauth_provider.context.client_metadata.scope is None + @pytest.mark.anyio async def test_register_client_request(self, oauth_provider: OAuthClientProvider): """Test client registration request building.""" @@ -761,6 +858,98 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # Verify exactly one request was yielded (no double-sending) assert request_yields == 1, f"Expected 1 request yield, got {request_yields}" + @pytest.mark.anyio + async def test_403_insufficient_scope_updates_scope_from_header( + self, + oauth_provider: OAuthClientProvider, + mock_storage: MockTokenStorage, + valid_tokens: OAuthToken, + ): + """Test that 403 response correctly updates scope from WWW-Authenticate header.""" + # Pre-store valid tokens and client info + client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + await mock_storage.set_tokens(valid_tokens) + await mock_storage.set_client_info(client_info) + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() + 1800 + oauth_provider.context.client_info = client_info + oauth_provider._initialized = True + + # Original scope + assert oauth_provider.context.client_metadata.scope == "read write" + + redirect_captured = False + captured_state = None + + async def capture_redirect(url: str) -> None: + nonlocal redirect_captured, captured_state + redirect_captured = True + # Verify the new scope is included in authorization URL + assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace( + "%3A", ":" + ).replace("+", " ") + # Extract state from redirect URL + from urllib.parse import parse_qs, urlparse + + parsed = urlparse(url) + params = parse_qs(parsed.query) + captured_state = params.get("state", [None])[0] + + oauth_provider.context.redirect_handler = capture_redirect + + # Mock callback + async def mock_callback() -> tuple[str, str | None]: + return "auth_code", captured_state + + oauth_provider.context.callback_handler = mock_callback + + test_request = httpx.Request("GET", "https://api.example.com/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + + # Send 403 with new scope requirement + response_403 = httpx.Response( + 403, + headers={"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="admin:write admin:delete"'}, + request=request, + ) + + # Trigger step-up - should get token exchange request + token_exchange_request = await auth_flow.asend(response_403) + + # Verify scope was updated + assert oauth_provider.context.client_metadata.scope == "admin:write admin:delete" + assert redirect_captured + + # Complete the flow with successful token response + token_response = httpx.Response( + 200, + json={ + "access_token": "new_token_with_new_scope", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "admin:write admin:delete", + }, + request=token_exchange_request, + ) + + # Should get final retry request + final_request = await auth_flow.asend(token_response) + + # Send success response - flow should complete + success_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(success_response) + pytest.fail("Should have stopped after successful response") + except StopAsyncIteration: + pass # Expected + @pytest.mark.parametrize( ( @@ -841,45 +1030,64 @@ def test_build_metadata( ) -class TestProtectedResourceWWWAuthenticate: - """Test RFC9728 WWW-Authenticate header parsing functionality for protected resource.""" +class TestWWWAuthenticate: + """Test WWW-Authenticate header parsing functionality.""" @pytest.mark.parametrize( - "www_auth_header,expected_url", + "www_auth_header,field_name,expected_value", [ - # Quoted URL + # Quoted values + ('Bearer scope="read write"', "scope", "read write"), ( 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', + "resource_metadata", "https://api.example.com/.well-known/oauth-protected-resource", ), - # Unquoted URL + ('Bearer error="insufficient_scope"', "error", "insufficient_scope"), + # Unquoted values + ("Bearer scope=read", "scope", "read"), ( "Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource", + "resource_metadata", "https://api.example.com/.well-known/oauth-protected-resource", ), - # Complex header with multiple parameters + ("Bearer error=invalid_token", "error", "invalid_token"), + # Multiple parameters with quoted value + ( + 'Bearer realm="api", scope="admin:write resource:read", error="insufficient_scope"', + "scope", + "admin:write resource:read", + ), ( 'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", ' 'error="insufficient_scope"', + "resource_metadata", "https://api.example.com/.well-known/oauth-protected-resource", ), - # Different URL format - ('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"), - # With path and query params + # Multiple parameters with unquoted value + ('Bearer realm="api", scope=basic', "scope", "basic"), + # Values with special characters + ( + 'Bearer scope="resource:read resource:write user_profile"', + "scope", + "resource:read resource:write user_profile", + ), ( 'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"', + "resource_metadata", "https://api.example.com/auth/metadata?version=1", ), ], ) - def test_extract_resource_metadata_from_www_auth_valid_cases( + def test_extract_field_from_www_auth_valid_cases( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage, www_auth_header: str, - expected_url: str, + field_name: str, + expected_value: str, ): - """Test extraction of resource_metadata URL from various valid WWW-Authenticate headers.""" + """Test extraction of various fields from valid WWW-Authenticate headers.""" async def redirect_handler(url: str) -> None: pass @@ -901,39 +1109,30 @@ async def callback_handler() -> tuple[str, str | None]: request=httpx.Request("GET", "https://api.example.com/test"), ) - result = provider._extract_resource_metadata_from_www_auth(init_response) - assert result == expected_url + result = provider._extract_field_from_www_auth(init_response, field_name) + assert result == expected_value @pytest.mark.parametrize( - "status_code,www_auth_header,description", + "www_auth_header,field_name,description", [ # No header - (401, None, "no WWW-Authenticate header"), + (None, "scope", "no WWW-Authenticate header"), # Empty header - (401, "", "empty WWW-Authenticate header"), - # Header without resource_metadata - (401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"), - # Malformed header - (401, "Bearer resource_metadata=", "malformed resource_metadata parameter"), - # Non-401 status code - ( - 200, - 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', - "200 OK response", - ), - ( - 500, - 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"', - "500 error response", - ), + ("", "scope", "empty WWW-Authenticate header"), + # Header without requested field + ('Bearer realm="api", error="insufficient_scope"', "scope", "no scope parameter"), + ('Bearer realm="api", scope="read write"', "resource_metadata", "no resource_metadata parameter"), + # Malformed field (empty value) + ("Bearer scope=", "scope", "malformed scope parameter"), + ("Bearer resource_metadata=", "resource_metadata", "malformed resource_metadata parameter"), ], ) - def test_extract_resource_metadata_from_www_auth_invalid_cases( + def test_extract_field_from_www_auth_invalid_cases( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage, - status_code: int, www_auth_header: str | None, + field_name: str, description: str, ): """Test extraction returns None for invalid cases.""" @@ -954,8 +1153,8 @@ async def callback_handler() -> tuple[str, str | None]: headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} init_response = httpx.Response( - status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") + status_code=401, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") ) - result = provider._extract_resource_metadata_from_www_auth(init_response) + result = provider._extract_field_from_www_auth(init_response, field_name) assert result is None, f"Should return None for {description}"