diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java index 61e387f408c..4b2600a2111 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java @@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.util.Assert; import reactor.core.publisher.Mono; import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; @@ -48,11 +49,18 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClient implements Re private WebClient webClient = WebClient.builder() .build(); + /** + * @param webClient the webClient to set + */ + public void setWebClient(WebClient webClient) { + Assert.notNull(webClient, "webClient cannot be null"); + this.webClient = webClient; + } + @Override public Mono getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) { return Mono.defer(() -> { ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); - OAuth2AuthorizationExchange authorizationExchange = authorizationGrantRequest.getAuthorizationExchange(); String tokenUri = clientRegistration.getProviderDetails().getTokenUri(); BodyInserters.FormInserter body = body(authorizationExchange); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java index f0185a1d6e2..dc129a5d800 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java @@ -32,11 +32,13 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; +import org.springframework.web.reactive.function.client.WebClient; import java.time.Instant; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.*; /** * @author Rob Winch @@ -259,4 +261,31 @@ private MockResponse jsonResponse(String json) { .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) .setBody(json); } + + @Test(expected=IllegalArgumentException.class) + public void setWebClientNullThenIllegalArgumentException(){ + tokenResponseClient.setWebClient(null); + } + + @Test + public void setCustomWebClientThenCustomWebClientIsUsed() { + WebClient customClient = mock(WebClient.class); + when(customClient.post()).thenReturn(WebClient.builder().build().post()); + + tokenResponseClient.setWebClient(customClient); + + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"openid profile\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + this.clientRegistration.scope("openid", "profile", "email", "address"); + + OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block(); + + verify(customClient, atLeastOnce()).post(); + } }