Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions api/src/main/java/io/grpc/LoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ public static final class PickResult {
private final Status status;
// True if the result is created by withDrop()
private final boolean drop;
@Nullable private final String authorityOverrideHostname;

private PickResult(
@Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
Expand All @@ -560,6 +561,17 @@ private PickResult(
this.streamTracerFactory = streamTracerFactory;
this.status = checkNotNull(status, "status");
this.drop = drop;
this.authorityOverrideHostname = null;
}

private PickResult(
@Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
Status status, boolean drop, @Nullable String authorityOverrideHostname) {
this.subchannel = subchannel;
this.streamTracerFactory = streamTracerFactory;
this.status = checkNotNull(status, "status");
this.drop = drop;
this.authorityOverrideHostname = authorityOverrideHostname;
}

/**
Expand Down Expand Up @@ -639,6 +651,18 @@ public static PickResult withSubchannel(
false);
}

/**
* Same as {@code withSubchannel(subchannel, streamTracerFactory)} but with an authority name
* to override in the host header.
*/
public static PickResult withSubchannel(
Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
@Nullable String authorityOverrideHostname) {
return new PickResult(
checkNotNull(subchannel, "subchannel"), streamTracerFactory, Status.OK,
false, authorityOverrideHostname);
}

/**
* Equivalent to {@code withSubchannel(subchannel, null)}.
*
Expand Down Expand Up @@ -682,6 +706,12 @@ public static PickResult withNoResult() {
return NO_RESULT;
}

/** Returns the authority override hostname if any. */
@Nullable
public String getAuthorityOverrideHostname() {
return authorityOverrideHostname;
}

/**
* The Subchannel if this result was created by {@link #withSubchannel withSubchannel()}, or
* null otherwise.
Expand Down Expand Up @@ -736,6 +766,7 @@ public String toString() {
.add("streamTracerFactory", streamTracerFactory)
.add("status", status)
.add("drop", drop)
.add("authority-override", authorityOverrideHostname)
.toString();
}

Expand Down
12 changes: 11 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedClientTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,17 @@ public final ClientStream newStream(
}
if (state.lastPicker != null) {
PickResult pickResult = state.lastPicker.pickSubchannel(args);
callOptions = args.getCallOptions();
// User code provided authority takes precedence over the LB provided one.
if (callOptions.getAuthority() == null
&& pickResult.getAuthorityOverrideHostname() != null) {
callOptions = callOptions.withAuthority(pickResult.getAuthorityOverrideHostname());
}
ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult,
callOptions.isWaitForReady());
if (transport != null) {
return transport.newStream(
args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(),
args.getMethodDescriptor(), args.getHeaders(), callOptions,
tracers);
}
}
Expand Down Expand Up @@ -281,6 +287,10 @@ final void reprocess(@Nullable SubchannelPicker picker) {
for (final PendingStream stream : toProcess) {
PickResult pickResult = picker.pickSubchannel(stream.args);
CallOptions callOptions = stream.args.getCallOptions();
// User code provided authority takes precedence over the LB provided one.
if (callOptions.getAuthority() == null && pickResult.getAuthorityOverrideHostname() != null) {
stream.setAuthority(pickResult.getAuthorityOverrideHostname());
}
final ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult,
callOptions.isWaitForReady());
if (transport != null) {
Expand Down
1 change: 0 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ private void delayOrExecute(Runnable runnable) {

@Override
public void setAuthority(final String authority) {
checkState(listener == null, "May only be called before start");
checkNotNull(authority, "authority");
preStartPendingCalls.add(new Runnable() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,43 @@ public void uncaughtException(Thread t, Throwable e) {
verify(transportListener).transportTerminated();
}

@Test
public void reprocess_authorityOverridePresentInCallOptions_authorityOverrideFromLbIsIgnored() {
DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream(
method, headers, callOptions, tracers);
delayedStream.start(mock(ClientStreamListener.class));
SubchannelPicker picker = mock(SubchannelPicker.class);
PickResult pickResult = PickResult.withSubchannel(
mockSubchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);

delayedTransport.reprocess(picker);
fakeExecutor.runDueTasks();

verify(mockRealStream, never()).setAuthority("authority-override-hostname-from-lb");
}

@Test
public void
reprocess_authorityOverrideNotInCallOptions_authorityOverrideFromLbIsSetIntoStream() {
DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream(
method, headers, callOptions.withAuthority(null), tracers);
delayedStream.start(mock(ClientStreamListener.class));
SubchannelPicker picker = mock(SubchannelPicker.class);
PickResult pickResult = PickResult.withSubchannel(
mockSubchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
when(mockRealTransport.newStream(
same(method), same(headers), any(CallOptions.class),
ArgumentMatchers.any()))
.thenReturn(mockRealStream);

delayedTransport.reprocess(picker);
fakeExecutor.runDueTasks();

verify(mockRealStream).setAuthority("authority-override-hostname-from-lb");
}

@Test
public void reprocess_NoPendingStream() {
SubchannelPicker picker = mock(SubchannelPicker.class);
Expand All @@ -525,6 +562,55 @@ public void reprocess_NoPendingStream() {
assertSame(mockRealStream, stream);
}

@Test
public void newStream_assignsTransport_authorityFromCallOptionsSupersedesAuthorityFromLB() {
SubchannelPicker picker = mock(SubchannelPicker.class);
AbstractSubchannel subchannel = mock(AbstractSubchannel.class);
when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel);
PickResult pickResult = PickResult.withSubchannel(
subchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
ArgumentCaptor<CallOptions> callOptionsArgumentCaptor =
ArgumentCaptor.forClass(CallOptions.class);
when(mockRealTransport.newStream(
any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(),
ArgumentMatchers.<ClientStreamTracer[]>any()))
.thenReturn(mockRealStream);
delayedTransport.reprocess(picker);
verifyNoMoreInteractions(picker);
verifyNoMoreInteractions(transportListener);

CallOptions callOptions =
CallOptions.DEFAULT.withAuthority("authority-override-hosstname-from-calloptions");
delayedTransport.newStream(method, headers, callOptions, tracers);
assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo(
"authority-override-hosstname-from-calloptions");
}

@Test
public void newStream_assignsTransport_authorityFromLB() {
SubchannelPicker picker = mock(SubchannelPicker.class);
AbstractSubchannel subchannel = mock(AbstractSubchannel.class);
when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel);
PickResult pickResult = PickResult.withSubchannel(
subchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
ArgumentCaptor<CallOptions> callOptionsArgumentCaptor =
ArgumentCaptor.forClass(CallOptions.class);
when(mockRealTransport.newStream(
any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(),
ArgumentMatchers.<ClientStreamTracer[]>any()))
.thenReturn(mockRealStream);
delayedTransport.reprocess(picker);
verifyNoMoreInteractions(picker);
verifyNoMoreInteractions(transportListener);

CallOptions callOptions = CallOptions.DEFAULT;
delayedTransport.newStream(method, headers, callOptions, tracers);
assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo(
"authority-override-hostname-from-lb");
}

@Test
public void reprocess_newStreamRacesWithReprocess() throws Exception {
final CyclicBarrier barrier = new CyclicBarrier(2);
Expand Down
6 changes: 0 additions & 6 deletions core/src/test/java/io/grpc/internal/DelayedStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ public void setStream_setAuthority() {
inOrder.verify(realStream).start(any(ClientStreamListener.class));
}

@Test(expected = IllegalStateException.class)
public void setAuthority_afterStart() {
stream.start(listener);
stream.setAuthority("notgonnawork");
}

@Test(expected = IllegalStateException.class)
public void start_afterStart() {
stream.start(listener);
Expand Down
27 changes: 21 additions & 6 deletions xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.ForwardingClientStreamTracer;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool;
import io.grpc.services.MetricReport;
import io.grpc.util.ForwardingLoadBalancerHelper;
Expand Down Expand Up @@ -231,10 +232,16 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
args.getAddresses().get(0).getAttributes());
AtomicReference<ClusterLocality> localityAtomicReference = new AtomicReference<>(
clusterLocality);
Attributes attrs = args.getAttributes().toBuilder()
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference)
.build();
args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build();
Attributes.Builder attrsBuilder = args.getAttributes().toBuilder()
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference);
if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) {
String hostname = args.getAddresses().get(0).getAttributes()
.get(InternalXdsAttributes.ATTR_ADDRESS_NAME);
if (hostname != null) {
attrsBuilder.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, hostname);
}
}
args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build();
final Subchannel subchannel = delegate().createSubchannel(args);

return new ForwardingSubchannel() {
Expand Down Expand Up @@ -389,7 +396,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
Status.UNAVAILABLE.withDescription("Dropped: " + dropOverload.category()));
}
}
final PickResult result = delegate.pickSubchannel(args);
PickResult result = delegate.pickSubchannel(args);
if (result.getStatus().isOk() && result.getSubchannel() != null) {
if (enableCircuitBreaking) {
if (inFlights.get() >= maxConcurrentRequests) {
Expand All @@ -415,9 +422,17 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
stats, inFlights, result.getStreamTracerFactory());
ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
.newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats));
return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory);
result = PickResult.withSubchannel(result.getSubchannel(),
orcaTracerFactory);
}
}
if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null
&& args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)) {
result = PickResult.withSubchannel(result.getSubchannel(),
result.getStreamTracerFactory(),
result.getSubchannel().getAttributes().get(
InternalXdsAttributes.ATTR_ADDRESS_NAME));
}
}
return result;
}
Expand Down
10 changes: 9 additions & 1 deletion xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ public void run() {
.set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT,
localityLbInfo.localityWeight())
.set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight)
.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, endpoint.hostname())
.build();
EquivalentAddressGroup eag = new EquivalentAddressGroup(
endpoint.eag().getAddresses(), attr);
Expand Down Expand Up @@ -567,7 +568,7 @@ void start() {
handleEndpointResolutionError();
return;
}
resolver.start(new NameResolverListener());
resolver.start(new NameResolverListener(dnsHostName));
}

void refresh() {
Expand Down Expand Up @@ -606,6 +607,12 @@ public void run() {
}

private class NameResolverListener extends NameResolver.Listener2 {
private final String dnsHostName;

NameResolverListener(String dnsHostName) {
this.dnsHostName = dnsHostName;
}

@Override
public void onResult(final ResolutionResult resolutionResult) {
class NameResolved implements Runnable {
Expand All @@ -625,6 +632,7 @@ public void run() {
Attributes attr = eag.getAttributes().toBuilder()
.set(InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY)
.set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName)
.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, dnsHostName)
.build();
eag = new EquivalentAddressGroup(eag.getAddresses(), attr);
eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName));
Expand Down
10 changes: 6 additions & 4 deletions xds/src/main/java/io/grpc/xds/Endpoints.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,19 @@ abstract static class LbEndpoint {
// Whether the endpoint is healthy.
abstract boolean isHealthy();

abstract String hostname();

static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight,
boolean isHealthy) {
return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy);
boolean isHealthy, String hostname) {
return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy, hostname);
}

// Only for testing.
@VisibleForTesting
static LbEndpoint create(
String address, int port, int loadBalancingWeight, boolean isHealthy) {
String address, int port, int loadBalancingWeight, boolean isHealthy, String hostname) {
return LbEndpoint.create(new EquivalentAddressGroup(new InetSocketAddress(address, port)),
loadBalancingWeight, isHealthy);
loadBalancingWeight, isHealthy, hostname);
}
}

Expand Down
5 changes: 5 additions & 0 deletions xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ public final class InternalXdsAttributes {
static final Attributes.Key<Long> ATTR_SERVER_WEIGHT =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight");

/** Name associated with individual address, if available (e.g., DNS name). */
@EquivalentAddressGroup.Attr
static final Attributes.Key<String> ATTR_ADDRESS_NAME =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.addressName");

/**
* Filter chain match for network filters.
*/
Expand Down
Loading