diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index ab86cc0041e..1f622bee8c5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -23,7 +23,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import javax.crypto.spec.SecretKeySpec; @@ -78,8 +77,6 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory jwtDecoders = new ConcurrentHashMap<>(); - private Function> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); private Function jwsAlgorithmResolver = ( @@ -135,16 +132,15 @@ public static ClaimTypeConverter createDefaultClaimTypeConverter() { @Override public JwtDecoder createDecoder(ClientRegistration clientRegistration) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { - NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration); - jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); - Converter, Map> claimTypeConverter = this.claimTypeConverterFactory - .apply(clientRegistration); - if (claimTypeConverter != null) { - jwtDecoder.setClaimSetConverter(claimTypeConverter); - } - return jwtDecoder; - }); + + NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration); + jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); + if (claimTypeConverter != null) { + jwtDecoder.setClaimSetConverter(claimTypeConverter); + } + return jwtDecoder; } private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java index 97676e975da..f342f3e15e1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java @@ -23,7 +23,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import javax.crypto.spec.SecretKeySpec; @@ -80,8 +79,6 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( createDefaultClaimTypeConverters()); - private final Map jwtDecoders = new ConcurrentHashMap<>(); - private Function> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); private Function jwsAlgorithmResolver = ( @@ -126,16 +123,14 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod @Override public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) { Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { - NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration); - jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); - Converter, Map> claimTypeConverter = this.claimTypeConverterFactory - .apply(clientRegistration); - if (claimTypeConverter != null) { - jwtDecoder.setClaimSetConverter(claimTypeConverter); - } - return jwtDecoder; - }); + NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration); + jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); + Converter, Map> claimTypeConverter = this.claimTypeConverterFactory + .apply(clientRegistration); + if (claimTypeConverter != null) { + jwtDecoder.setClaimSetConverter(claimTypeConverter); + } + return jwtDecoder; } private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java index c8727804841..5a3c4c21f67 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -46,6 +47,7 @@ /** * @author Joe Grandja * @author Rafael Dominguez + * @author Ivan Golovko * @since 5.2 */ public class OidcIdTokenDecoderFactoryTests { @@ -177,4 +179,13 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderTwiceWithoutCaching() { + this.idTokenDecoderFactory = new OidcIdTokenDecoderFactory(); + ClientRegistration clientRegistration = this.registration.build(); + JwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + JwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + assertThat(decoder1).isNotSameAs(decoder2); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java index 7adddcc029d..40692816a51 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -47,6 +48,7 @@ * @author Joe Grandja * @author Rafael Dominguez * @author Ubaid ur Rehman + * @author Ivan Golovko * @since 5.2 */ public class ReactiveOidcIdTokenDecoderFactoryTests { @@ -177,4 +179,13 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderTwice() { + this.idTokenDecoderFactory = new ReactiveOidcIdTokenDecoderFactory(); + ClientRegistration clientRegistration = this.registration.build(); + ReactiveJwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + ReactiveJwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration); + assertThat(decoder1).isNotSameAs(decoder2); + } + }