Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,14 @@ private MethodSpec authSchemeWithEndpointSignerPropertiesMethod() {

method.beginControlFlow("for ($T endpointAuthScheme : endpointAuthSchemes)", EndpointAuthScheme.class);

if (useSraAuth) {
// Don't include signer properties for auth options that don't match our selected auth scheme
method.beginControlFlow("if (!endpointAuthScheme.schemeId()"
+ ".equals(selectedAuthScheme.authSchemeOption().schemeId()))");
method.addStatement("continue");
method.endControlFlow();
}

method.addStatement("$T option = selectedAuthScheme.authSchemeOption().toBuilder()", AuthSchemeOption.Builder.class);

if (dependsOnHttpAuthAws) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,27 @@
import static software.amazon.awssdk.codegen.poet.PoetMatchers.generatesTo;

import org.junit.jupiter.api.Test;
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.poet.ClassSpec;
import software.amazon.awssdk.codegen.poet.ClientTestModels;

public class EndpointResolverInterceptorSpecTest {
@Test
public void endpointResolverInterceptorClass() {
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.queryServiceModels());
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(getModel(true));
assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor.java"));
}

// TODO(post-sra-identity-auth): This can be deleted when useSraAuth is removed
@Test
public void endpointResolverInterceptorClass_preSra() {
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(getModel(false));
assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-preSra.java"));
}

private static IntermediateModel getModel(boolean useSraAuth) {
IntermediateModel model = ClientTestModels.queryServiceModels();
model.getCustomizationConfig().setUseSraAuth(useSraAuth);
return model;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package software.amazon.awssdk.services.query.endpoints.internal;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.SignerLoader;
import software.amazon.awssdk.awscore.AwsExecutionAttribute;
import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute;
import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme;
import software.amazon.awssdk.awscore.util.SignerOverrideUtils;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SelectedAuthScheme;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.RegionSet;
import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption;
import software.amazon.awssdk.identity.spi.Identity;
import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams;
import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams;
import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider;
import software.amazon.awssdk.services.query.model.OperationWithContextParamRequest;
import software.amazon.awssdk.utils.AttributeMap;

@Generated("software.amazon.awssdk:codegen")
@SdkInternalApi
public final class QueryResolveEndpointInterceptor implements ExecutionInterceptor {
@Override
public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttributes executionAttributes) {
SdkRequest result = context.request();
if (AwsEndpointProviderUtils.endpointIsDiscovered(executionAttributes)) {
return result;
}
QueryEndpointProvider provider = (QueryEndpointProvider) executionAttributes
.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER);
try {
Endpoint endpoint = provider.resolveEndpoint(ruleParams(result, executionAttributes)).join();
if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) {
Optional<String> hostPrefix = hostPrefix(executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME),
result);
if (hostPrefix.isPresent()) {
endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get());
}
}
List<EndpointAuthScheme> endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES);
SelectedAuthScheme<?> selectedAuthScheme = executionAttributes
.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME);
if (endpointAuthSchemes != null && selectedAuthScheme != null) {
selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme);
executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme);
}
if (endpointAuthSchemes != null) {
EndpointAuthScheme chosenAuthScheme = AuthSchemeUtils.chooseAuthScheme(endpointAuthSchemes);
Supplier<Signer> signerProvider = signerProvider(chosenAuthScheme);
result = SignerOverrideUtils.overrideSignerIfNotOverridden(result, executionAttributes, signerProvider);
}
executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint);
return result;
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof SdkClientException) {
throw (SdkClientException) cause;
} else {
throw SdkClientException.create("Endpoint resolution failed", cause);
}
}
}

@Override
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
Endpoint resolvedEndpoint = executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT);
if (resolvedEndpoint.headers().isEmpty()) {
return context.httpRequest();
}
SdkHttpRequest.Builder httpRequestBuilder = context.httpRequest().toBuilder();
resolvedEndpoint.headers().forEach((name, values) -> {
values.forEach(v -> httpRequestBuilder.appendHeader(name, v));
});
return httpRequestBuilder.build();
}

public static QueryEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) {
QueryEndpointParams.Builder builder = QueryEndpointParams.builder();
builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes));
builder.useDualStackEndpoint(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(executionAttributes));
builder.useFipsEndpoint(AwsEndpointProviderUtils.fipsEnabledBuiltIn(executionAttributes));
setClientContextParams(builder, executionAttributes);
setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request);
setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME));
return builder.build();
}

private static void setContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) {
switch (operationName) {
case "OperationWithContextParam":
setContextParams(params, (OperationWithContextParamRequest) request);
break;
default:
break;
}
}

private static void setContextParams(QueryEndpointParams.Builder params, OperationWithContextParamRequest request) {
params.operationContextParam(request.stringMember());
}

private static void setStaticContextParams(QueryEndpointParams.Builder params, String operationName) {
switch (operationName) {
case "OperationWithStaticContextParams":
operationWithStaticContextParamsStaticContextParams(params);
break;
default:
break;
}
}

private static void operationWithStaticContextParamsStaticContextParams(QueryEndpointParams.Builder params) {
params.staticStringParam("hello");
}

private <T extends Identity> SelectedAuthScheme<T> authSchemeWithEndpointSignerProperties(
List<EndpointAuthScheme> endpointAuthSchemes, SelectedAuthScheme<T> selectedAuthScheme) {
for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) {
AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder();
if (endpointAuthScheme instanceof SigV4AuthScheme) {
SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme;
if (v4AuthScheme.isDisableDoubleEncodingSet()) {
option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding());
}
if (v4AuthScheme.signingRegion() != null) {
option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion());
}
if (v4AuthScheme.signingName() != null) {
option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName());
}
return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build());
}
if (endpointAuthScheme instanceof SigV4aAuthScheme) {
SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme;
if (v4aAuthScheme.isDisableDoubleEncodingSet()) {
option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding());
}
if (v4aAuthScheme.signingRegionSet() != null) {
RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet());
option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet);
}
if (v4aAuthScheme.signingName() != null) {
option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName());
}
return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build());
}
throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name()
+ "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?");
}
return selectedAuthScheme;
}

private static void setClientContextParams(QueryEndpointParams.Builder params, ExecutionAttributes executionAttributes) {
AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS);
Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent(
params::booleanContextParam);
Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent(
params::stringContextParam);
}

private static Optional<String> hostPrefix(String operationName, SdkRequest request) {
switch (operationName) {
case "APostOperation": {
return Optional.of("foo-");
}
default:
return Optional.empty();
}
}

private Supplier<Signer> signerProvider(EndpointAuthScheme authScheme) {
switch (authScheme.name()) {
case "sigv4":
return Aws4Signer::create;
case "sigv4a":
return SignerLoader::getSigV4aSigner;
default:
break;
}
throw SdkClientException.create("Don't know how to create signer for auth scheme: " + authScheme.name());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.Generated;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.SignerLoader;
import software.amazon.awssdk.awscore.AwsExecutionAttribute;
import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute;
import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme;
import software.amazon.awssdk.awscore.util.SignerOverrideUtils;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SelectedAuthScheme;
import software.amazon.awssdk.core.exception.SdkClientException;
Expand All @@ -22,7 +18,6 @@
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
Expand Down Expand Up @@ -63,11 +58,6 @@ public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttribut
selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme);
executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme);
}
if (endpointAuthSchemes != null) {
EndpointAuthScheme chosenAuthScheme = AuthSchemeUtils.chooseAuthScheme(endpointAuthSchemes);
Supplier<Signer> signerProvider = signerProvider(chosenAuthScheme);
result = SignerOverrideUtils.overrideSignerIfNotOverridden(result, executionAttributes, signerProvider);
}
executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint);
return result;
} catch (CompletionException e) {
Expand Down Expand Up @@ -135,6 +125,9 @@ private static void operationWithStaticContextParamsStaticContextParams(QueryEnd
private <T extends Identity> SelectedAuthScheme<T> authSchemeWithEndpointSignerProperties(
List<EndpointAuthScheme> endpointAuthSchemes, SelectedAuthScheme<T> selectedAuthScheme) {
for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) {
if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) {
continue;
}
AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder();
if (endpointAuthScheme instanceof SigV4AuthScheme) {
SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme;
Expand Down Expand Up @@ -186,16 +179,4 @@ private static Optional<String> hostPrefix(String operationName, SdkRequest requ
return Optional.empty();
}
}

private Supplier<Signer> signerProvider(EndpointAuthScheme authScheme) {
switch (authScheme.name()) {
case "sigv4":
return Aws4Signer::create;
case "sigv4a":
return SignerLoader::getSigV4aSigner;
default:
break;
}
throw SdkClientException.create("Don't know how to create signer for auth scheme: " + authScheme.name());
}
}