1717package org .springframework .security .oauth2 .client .web .function .client ;
1818
1919import java .io .IOException ;
20- import java .net .URI ;
2120import java .util .HashMap ;
2221import java .util .Map ;
2322import java .util .function .Consumer ;
2625import jakarta .servlet .http .HttpServletResponse ;
2726
2827import org .springframework .http .HttpHeaders ;
29- import org .springframework .http .HttpMethod ;
3028import org .springframework .http .HttpRequest ;
3129import org .springframework .http .HttpStatus ;
3230import org .springframework .http .HttpStatusCode ;
4745import org .springframework .security .oauth2 .client .OAuth2AuthorizedClientProvider ;
4846import org .springframework .security .oauth2 .client .OAuth2AuthorizedClientService ;
4947import org .springframework .security .oauth2 .client .RemoveAuthorizedClientOAuth2AuthorizationFailureHandler ;
48+ import org .springframework .security .oauth2 .client .authentication .OAuth2AuthenticationToken ;
5049import org .springframework .security .oauth2 .client .registration .ClientRegistration ;
5150import org .springframework .security .oauth2 .client .web .OAuth2AuthorizedClientRepository ;
5251import org .springframework .security .oauth2 .core .OAuth2AuthorizationException ;
5554import org .springframework .security .oauth2 .core .endpoint .OAuth2ParameterNames ;
5655import org .springframework .util .Assert ;
5756import org .springframework .util .StringUtils ;
58- import org .springframework .web .client .DefaultResponseErrorHandler ;
59- import org .springframework .web .client .ResponseErrorHandler ;
57+ import org .springframework .web .client .RestClient ;
6058import org .springframework .web .client .RestClientResponseException ;
6159import org .springframework .web .context .request .RequestContextHolder ;
6260import org .springframework .web .context .request .ServletRequestAttributes ;
@@ -116,9 +114,12 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
116114 private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken ("anonymous" ,
117115 "anonymousUser" , AuthorityUtils .createAuthorityList ("ROLE_ANONYMOUS" ));
118116
117+ private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2ClientHttpRequestInterceptor .class .getName ()
118+ .concat (".clientRegistrationId" );
119+
119120 private final OAuth2AuthorizedClientManager authorizedClientManager ;
120121
121- private final String clientRegistrationId ;
122+ private String defaultClientRegistrationId ;
122123
123124 // @formatter:off
124125 private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
@@ -133,15 +134,27 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque
133134 * parameters.
134135 * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
135136 * manages the authorized client(s)
136- * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
137- * be used to look up the {@link OAuth2AuthorizedClient}
138137 */
139- public OAuth2ClientHttpRequestInterceptor (OAuth2AuthorizedClientManager authorizedClientManager ,
140- String clientRegistrationId ) {
138+ public OAuth2ClientHttpRequestInterceptor (OAuth2AuthorizedClientManager authorizedClientManager ) {
141139 Assert .notNull (authorizedClientManager , "authorizedClientManager cannot be null" );
142- Assert .hasText (clientRegistrationId , "clientRegistrationId cannot be empty" );
143140 this .authorizedClientManager = authorizedClientManager ;
144- this .clientRegistrationId = clientRegistrationId ;
141+ }
142+
143+ /**
144+ * Sets the default {@code clientRegistrationId} to be used for resolving an
145+ * {@link OAuth2AuthorizedClient}.
146+ *
147+ * <p>
148+ * By default, the {@code clientRegistrationId} is obtained from the current
149+ * {@link Authentication principal}. Using this setter overrides the default, but can
150+ * be overridden by providing an
151+ * {@link RestClient.RequestHeadersSpec#attributes(Consumer) attribute} via
152+ * {@link #clientRegistrationId(String)}.
153+ * @param clientRegistrationId the default {@code clientRegistrationId}
154+ */
155+ public void setDefaultClientRegistrationId (String clientRegistrationId ) {
156+ Assert .hasText (clientRegistrationId , "clientRegistrationId cannot be empty" );
157+ this .defaultClientRegistrationId = clientRegistrationId ;
145158 }
146159
147160 /**
@@ -237,33 +250,52 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
237250 this .securityContextHolderStrategy = securityContextHolderStrategy ;
238251 }
239252
253+ /**
254+ * Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
255+ * {@link ClientRegistration#getRegistrationId() clientRegistrationId} to be used to
256+ * look up the {@link OAuth2AuthorizedClient}.
257+ * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()
258+ * clientRegistrationId} to be used to look up the {@link OAuth2AuthorizedClient}
259+ * @return the {@link Consumer} to populate the attributes
260+ */
261+ public static Consumer <Map <String , Object >> clientRegistrationId (String clientRegistrationId ) {
262+ Assert .hasText (clientRegistrationId , "clientRegistrationId cannot be empty" );
263+ return (attributes ) -> attributes .put (CLIENT_REGISTRATION_ID_ATTR_NAME , clientRegistrationId );
264+ }
265+
240266 @ Override
241267 public ClientHttpResponse intercept (HttpRequest request , byte [] body , ClientHttpRequestExecution execution )
242268 throws IOException {
243- authorizeClient (request );
269+ Authentication principal = this .securityContextHolderStrategy .getContext ().getAuthentication ();
270+ if (principal == null ) {
271+ principal = ANONYMOUS_AUTHENTICATION ;
272+ }
273+
274+ authorizeClient (request , principal );
244275 try {
245276 ClientHttpResponse response = execution .execute (request , body );
246- handleAuthorizationFailure (response .getHeaders (), response .getStatusCode ());
277+ handleAuthorizationFailure (request , principal , response .getHeaders (), response .getStatusCode ());
247278 return response ;
248279 }
249280 catch (RestClientResponseException ex ) {
250- handleAuthorizationFailure (ex .getResponseHeaders (), ex .getStatusCode ());
281+ handleAuthorizationFailure (request , principal , ex .getResponseHeaders (), ex .getStatusCode ());
251282 throw ex ;
252283 }
253284 catch (OAuth2AuthorizationException ex ) {
254- handleAuthorizationFailure (ex );
285+ handleAuthorizationFailure (ex , principal );
255286 throw ex ;
256287 }
257288 }
258289
259- private void authorizeClient (HttpRequest request ) {
260- Authentication principal = this . securityContextHolderStrategy . getContext (). getAuthentication ( );
261- if (principal == null ) {
262- principal = ANONYMOUS_AUTHENTICATION ;
290+ private void authorizeClient (HttpRequest request , Authentication principal ) {
291+ String clientRegistrationId = clientRegistrationId ( request , principal );
292+ if (clientRegistrationId == null ) {
293+ return ;
263294 }
295+
264296 // @formatter:off
265297 OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
266- .withClientRegistrationId (this . clientRegistrationId )
298+ .withClientRegistrationId (clientRegistrationId )
267299 .principal (principal )
268300 .build ();
269301 // @formatter:on
@@ -273,15 +305,21 @@ private void authorizeClient(HttpRequest request) {
273305 }
274306 }
275307
276- private void handleAuthorizationFailure (HttpHeaders headers , HttpStatusCode httpStatus ) {
308+ private void handleAuthorizationFailure (HttpRequest request , Authentication principal , HttpHeaders headers ,
309+ HttpStatusCode httpStatus ) {
277310 OAuth2Error error = resolveOAuth2ErrorIfPossible (headers , httpStatus );
278311 if (error == null ) {
279312 return ;
280313 }
281314
315+ String clientRegistrationId = clientRegistrationId (request , principal );
316+ if (clientRegistrationId == null ) {
317+ return ;
318+ }
319+
282320 ClientAuthorizationException authorizationException = new ClientAuthorizationException (error ,
283- this . clientRegistrationId );
284- handleAuthorizationFailure (authorizationException );
321+ clientRegistrationId );
322+ handleAuthorizationFailure (authorizationException , principal );
285323 }
286324
287325 private static OAuth2Error resolveOAuth2ErrorIfPossible (HttpHeaders headers , HttpStatusCode httpStatus ) {
@@ -323,12 +361,20 @@ private static Map<String, String> parseWwwAuthenticateHeader(String wwwAuthenti
323361 return parameters ;
324362 }
325363
326- private void handleAuthorizationFailure (OAuth2AuthorizationException authorizationException ) {
327- Authentication principal = this .securityContextHolderStrategy .getContext ().getAuthentication ();
328- if (principal == null ) {
329- principal = ANONYMOUS_AUTHENTICATION ;
364+ private String clientRegistrationId (HttpRequest request , Authentication principal ) {
365+ String clientRegistrationId = (String ) request .getAttributes ().get (CLIENT_REGISTRATION_ID_ATTR_NAME );
366+ if (clientRegistrationId == null ) {
367+ clientRegistrationId = this .defaultClientRegistrationId ;
368+ }
369+ if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken authentication ) {
370+ clientRegistrationId = authentication .getAuthorizedClientRegistrationId ();
330371 }
331372
373+ return clientRegistrationId ;
374+ }
375+
376+ private void handleAuthorizationFailure (OAuth2AuthorizationException authorizationException ,
377+ Authentication principal ) {
332378 ServletRequestAttributes requestAttributes = (ServletRequestAttributes ) RequestContextHolder
333379 .getRequestAttributes ();
334380 Map <String , Object > attributes = new HashMap <>();
0 commit comments