Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 91 additions & 35 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DecoratingHttp2FrameWriter;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
Expand All @@ -83,6 +84,7 @@
import io.netty.handler.codec.http2.Http2Headers;
import io.netty.handler.codec.http2.Http2HeadersDecoder;
import io.netty.handler.codec.http2.Http2InboundFrameLogger;
import io.netty.handler.codec.http2.Http2LifecycleManager;
import io.netty.handler.codec.http2.Http2OutboundFrameLogger;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.codec.http2.Http2Stream;
Expand Down Expand Up @@ -125,13 +127,11 @@ class NettyServerHandler extends AbstractNettyHandler {
private final long keepAliveTimeoutInNanos;
private final long maxConnectionAgeInNanos;
private final long maxConnectionAgeGraceInNanos;
private final int maxRstCount;
private final long maxRstPeriodNanos;
private final RstStreamCounter rstStreamCounter;
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
private final TransportTracer transportTracer;
private final KeepAliveEnforcer keepAliveEnforcer;
private final Attributes eagAttributes;
private final Ticker ticker;
/** Incomplete attributes produced by negotiator. */
private Attributes negotiationAttributes;
private InternalChannelz.Security securityInfo;
Expand All @@ -149,8 +149,6 @@ class NettyServerHandler extends AbstractNettyHandler {
private ScheduledFuture<?> maxConnectionAgeMonitor;
@CheckForNull
private GracefulShutdown gracefulShutdown;
private int rstCount;
private long lastRstNanoTime;

static NettyServerHandler newHandler(
ServerTransportListener transportListener,
Expand Down Expand Up @@ -251,13 +249,20 @@ static NettyServerHandler newHandler(
final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer(
permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS);

if (ticker == null) {
ticker = Ticker.systemTicker();
}

RstStreamCounter rstStreamCounter
= new RstStreamCounter(maxRstCount, maxRstPeriodNanos, ticker);
// Create the local flow controller configured to auto-refill the connection window.
connection.local().flowController(
new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true));
frameWriter = new WriteMonitoringFrameWriter(frameWriter, keepAliveEnforcer);
Http2ConnectionEncoder encoder =
new DefaultHttp2ConnectionEncoder(connection, frameWriter);
encoder = new Http2ControlFrameLimitEncoder(encoder, 10000);
encoder = new Http2RstCounterEncoder(encoder, rstStreamCounter);
Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder,
frameReader);

Expand All @@ -266,10 +271,6 @@ static NettyServerHandler newHandler(
settings.maxConcurrentStreams(maxStreams);
settings.maxHeaderListSize(maxHeaderListSize);

if (ticker == null) {
ticker = Ticker.systemTicker();
}

return new NettyServerHandler(
channelUnused,
connection,
Expand All @@ -286,8 +287,7 @@ static NettyServerHandler newHandler(
maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos,
keepAliveEnforcer,
autoFlowControl,
maxRstCount,
maxRstPeriodNanos,
rstStreamCounter,
eagAttributes, ticker);
}

Expand All @@ -310,8 +310,7 @@ private NettyServerHandler(
long maxConnectionAgeGraceInNanos,
final KeepAliveEnforcer keepAliveEnforcer,
boolean autoFlowControl,
int maxRstCount,
long maxRstPeriodNanos,
RstStreamCounter rstStreamCounter,
Attributes eagAttributes,
Ticker ticker) {
super(
Expand Down Expand Up @@ -363,12 +362,9 @@ public void onStreamClosed(Http2Stream stream) {
this.maxConnectionAgeInNanos = maxConnectionAgeInNanos;
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer");
this.maxRstCount = maxRstCount;
this.maxRstPeriodNanos = maxRstPeriodNanos;
this.rstStreamCounter = rstStreamCounter;
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
this.ticker = checkNotNull(ticker, "ticker");

this.lastRstNanoTime = ticker.read();
streamKey = encoder.connection().newKey();
this.transportListener = checkNotNull(transportListener, "transportListener");
this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories");
Expand Down Expand Up @@ -575,24 +571,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt
}

private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception {
if (maxRstCount > 0) {
long now = ticker.read();
if (now - lastRstNanoTime > maxRstPeriodNanos) {
lastRstNanoTime = now;
rstCount = 1;
} else {
rstCount++;
if (rstCount > maxRstCount) {
throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") {
@SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses
@Override
public Throwable fillInStackTrace() {
// Avoid the CPU cycles, since the resets may be a CPU consumption attack
return this;
}
};
}
}
Http2Exception tooManyRstStream = rstStreamCounter.countRstStream();
if (tooManyRstStream != null) {
throw tooManyRstStream;
}

try {
Expand Down Expand Up @@ -1180,6 +1161,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2
}
}

private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder {
private final RstStreamCounter rstStreamCounter;
private Http2LifecycleManager lifecycleManager;

Http2RstCounterEncoder(Http2ConnectionEncoder encoder, RstStreamCounter rstStreamCounter) {
super(encoder);
this.rstStreamCounter = rstStreamCounter;
}

@Override
public void lifecycleManager(Http2LifecycleManager lifecycleManager) {
this.lifecycleManager = lifecycleManager;
super.lifecycleManager(lifecycleManager);
}

@Override
public ChannelFuture writeRstStream(
ChannelHandlerContext ctx, int streamId, long errorCode, ChannelPromise promise) {
ChannelFuture future = super.writeRstStream(ctx, streamId, errorCode, promise);
// We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed
// frame.
boolean normalRst
= errorCode == Http2Error.NO_ERROR.code() || errorCode == Http2Error.CANCEL.code();
if (!normalRst) {
Http2Exception tooManyRstStream = rstStreamCounter.countRstStream();
if (tooManyRstStream != null) {
lifecycleManager.onError(ctx, true, tooManyRstStream);
ctx.close();
}
}
return future;
}
}

private static final class RstStreamCounter {
private final int maxRstCount;
private final long maxRstPeriodNanos;
private final Ticker ticker;
private int rstCount;
private long lastRstNanoTime;

RstStreamCounter(int maxRstCount, long maxRstPeriodNanos, Ticker ticker) {
checkArgument(maxRstCount >= 0, "maxRstCount must be non-negative: %s", maxRstCount);
this.maxRstCount = maxRstCount;
this.maxRstPeriodNanos = maxRstPeriodNanos;
this.ticker = checkNotNull(ticker, "ticker");
this.lastRstNanoTime = ticker.read();
}

/** Returns non-{@code null} when the connection should be killed by the caller. */
private Http2Exception countRstStream() {
if (maxRstCount == 0) {
return null;
}
long now = ticker.read();
if (now - lastRstNanoTime > maxRstPeriodNanos) {
lastRstNanoTime = now;
rstCount = 1;
} else {
rstCount++;
if (rstCount > maxRstCount) {
return new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") {
@SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses
@Override
public Throwable fillInStackTrace() {
// Avoid the CPU cycles, since the resets may be a CPU consumption attack
return this;
}
};
}
}
return null;
}
}

private static class ServerChannelLogger extends ChannelLogger {
private static final Logger log = Logger.getLogger(ChannelLogger.class.getName());

Expand Down
44 changes: 44 additions & 0 deletions netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,8 @@ public void maxRstCount_exceedsLimit_fails() throws Exception {
}

private void rapidReset(int burstSize) throws Exception {
when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class)))
.thenAnswer((args) -> new TestServerStreamTracer());
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8))
Expand All @@ -1323,6 +1325,48 @@ private void rapidReset(int burstSize) throws Exception {
}
}

@Test
public void maxRstCountSent_withinLimit_succeeds() throws Exception {
maxRstCount = 10;
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
madeYouReset(maxRstCount);

assertTrue(channel().isOpen());
}

@Test
public void maxRstCountSent_exceedsLimit_fails() throws Exception {
maxRstCount = 10;
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
manualSetUp();
assertThrows(ClosedChannelException.class, () -> madeYouReset(maxRstCount + 1));

assertFalse(channel().isOpen());
}

private void madeYouReset(int burstSize) throws Exception {
when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class)))
.thenAnswer((args) -> new TestServerStreamTracer());
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8))
.set(TE_HEADER, TE_TRAILERS)
.path(new AsciiString("/foo/bar"));
int streamId = 1;
long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize;
for (int period = 0; period < 3; period++) {
for (int i = 0; i < burstSize; i++) {
channelRead(headersFrame(streamId, headers));
channelRead(windowUpdate(streamId, 0));
streamId += 2;
fakeClock().forwardNanos(rpcTimeNanos);
}
while (channel().readOutbound() != null) {}
fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1);
}
}

private void createStream() throws Exception {
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
Expand Down