Skip to content

Commit a5c2b1a

Browse files
authored
netty: Count sent RST_STREAMs against limit (1.75.x backport) (#12288)
Http2RstCounterEncoder has to be constructed before NettyServerHandler/Http2ConnectionHandler so it must be static. Thus the code/counters were moved into RstStreamCounter which then can be constructed earlier and shared. This depends on Netty 4.1.124 for a bug fix to actually call the encoder: netty/netty@be53dc3 Backport of #12277
1 parent 0d3e828 commit a5c2b1a

File tree

2 files changed

+135
-35
lines changed

2 files changed

+135
-35
lines changed

netty/src/main/java/io/grpc/netty/NettyServerHandler.java

Lines changed: 91 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import io.netty.channel.ChannelFutureListener;
6161
import io.netty.channel.ChannelHandlerContext;
6262
import io.netty.channel.ChannelPromise;
63+
import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder;
6364
import io.netty.handler.codec.http2.DecoratingHttp2FrameWriter;
6465
import io.netty.handler.codec.http2.DefaultHttp2Connection;
6566
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
@@ -83,6 +84,7 @@
8384
import io.netty.handler.codec.http2.Http2Headers;
8485
import io.netty.handler.codec.http2.Http2HeadersDecoder;
8586
import io.netty.handler.codec.http2.Http2InboundFrameLogger;
87+
import io.netty.handler.codec.http2.Http2LifecycleManager;
8688
import io.netty.handler.codec.http2.Http2OutboundFrameLogger;
8789
import io.netty.handler.codec.http2.Http2Settings;
8890
import io.netty.handler.codec.http2.Http2Stream;
@@ -125,13 +127,11 @@ class NettyServerHandler extends AbstractNettyHandler {
125127
private final long keepAliveTimeoutInNanos;
126128
private final long maxConnectionAgeInNanos;
127129
private final long maxConnectionAgeGraceInNanos;
128-
private final int maxRstCount;
129-
private final long maxRstPeriodNanos;
130+
private final RstStreamCounter rstStreamCounter;
130131
private final List<? extends ServerStreamTracer.Factory> streamTracerFactories;
131132
private final TransportTracer transportTracer;
132133
private final KeepAliveEnforcer keepAliveEnforcer;
133134
private final Attributes eagAttributes;
134-
private final Ticker ticker;
135135
/** Incomplete attributes produced by negotiator. */
136136
private Attributes negotiationAttributes;
137137
private InternalChannelz.Security securityInfo;
@@ -149,8 +149,6 @@ class NettyServerHandler extends AbstractNettyHandler {
149149
private ScheduledFuture<?> maxConnectionAgeMonitor;
150150
@CheckForNull
151151
private GracefulShutdown gracefulShutdown;
152-
private int rstCount;
153-
private long lastRstNanoTime;
154152

155153
static NettyServerHandler newHandler(
156154
ServerTransportListener transportListener,
@@ -251,13 +249,20 @@ static NettyServerHandler newHandler(
251249
final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer(
252250
permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS);
253251

252+
if (ticker == null) {
253+
ticker = Ticker.systemTicker();
254+
}
255+
256+
RstStreamCounter rstStreamCounter
257+
= new RstStreamCounter(maxRstCount, maxRstPeriodNanos, ticker);
254258
// Create the local flow controller configured to auto-refill the connection window.
255259
connection.local().flowController(
256260
new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true));
257261
frameWriter = new WriteMonitoringFrameWriter(frameWriter, keepAliveEnforcer);
258262
Http2ConnectionEncoder encoder =
259263
new DefaultHttp2ConnectionEncoder(connection, frameWriter);
260264
encoder = new Http2ControlFrameLimitEncoder(encoder, 10000);
265+
encoder = new Http2RstCounterEncoder(encoder, rstStreamCounter);
261266
Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder,
262267
frameReader);
263268

@@ -266,10 +271,6 @@ static NettyServerHandler newHandler(
266271
settings.maxConcurrentStreams(maxStreams);
267272
settings.maxHeaderListSize(maxHeaderListSize);
268273

269-
if (ticker == null) {
270-
ticker = Ticker.systemTicker();
271-
}
272-
273274
return new NettyServerHandler(
274275
channelUnused,
275276
connection,
@@ -286,8 +287,7 @@ static NettyServerHandler newHandler(
286287
maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos,
287288
keepAliveEnforcer,
288289
autoFlowControl,
289-
maxRstCount,
290-
maxRstPeriodNanos,
290+
rstStreamCounter,
291291
eagAttributes, ticker);
292292
}
293293

@@ -310,8 +310,7 @@ private NettyServerHandler(
310310
long maxConnectionAgeGraceInNanos,
311311
final KeepAliveEnforcer keepAliveEnforcer,
312312
boolean autoFlowControl,
313-
int maxRstCount,
314-
long maxRstPeriodNanos,
313+
RstStreamCounter rstStreamCounter,
315314
Attributes eagAttributes,
316315
Ticker ticker) {
317316
super(
@@ -363,12 +362,9 @@ public void onStreamClosed(Http2Stream stream) {
363362
this.maxConnectionAgeInNanos = maxConnectionAgeInNanos;
364363
this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos;
365364
this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer");
366-
this.maxRstCount = maxRstCount;
367-
this.maxRstPeriodNanos = maxRstPeriodNanos;
365+
this.rstStreamCounter = rstStreamCounter;
368366
this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes");
369-
this.ticker = checkNotNull(ticker, "ticker");
370367

371-
this.lastRstNanoTime = ticker.read();
372368
streamKey = encoder.connection().newKey();
373369
this.transportListener = checkNotNull(transportListener, "transportListener");
374370
this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories");
@@ -575,24 +571,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt
575571
}
576572

577573
private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception {
578-
if (maxRstCount > 0) {
579-
long now = ticker.read();
580-
if (now - lastRstNanoTime > maxRstPeriodNanos) {
581-
lastRstNanoTime = now;
582-
rstCount = 1;
583-
} else {
584-
rstCount++;
585-
if (rstCount > maxRstCount) {
586-
throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") {
587-
@SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses
588-
@Override
589-
public Throwable fillInStackTrace() {
590-
// Avoid the CPU cycles, since the resets may be a CPU consumption attack
591-
return this;
592-
}
593-
};
594-
}
595-
}
574+
Http2Exception tooManyRstStream = rstStreamCounter.countRstStream();
575+
if (tooManyRstStream != null) {
576+
throw tooManyRstStream;
596577
}
597578

598579
try {
@@ -1180,6 +1161,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2
11801161
}
11811162
}
11821163

1164+
private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder {
1165+
private final RstStreamCounter rstStreamCounter;
1166+
private Http2LifecycleManager lifecycleManager;
1167+
1168+
Http2RstCounterEncoder(Http2ConnectionEncoder encoder, RstStreamCounter rstStreamCounter) {
1169+
super(encoder);
1170+
this.rstStreamCounter = rstStreamCounter;
1171+
}
1172+
1173+
@Override
1174+
public void lifecycleManager(Http2LifecycleManager lifecycleManager) {
1175+
this.lifecycleManager = lifecycleManager;
1176+
super.lifecycleManager(lifecycleManager);
1177+
}
1178+
1179+
@Override
1180+
public ChannelFuture writeRstStream(
1181+
ChannelHandlerContext ctx, int streamId, long errorCode, ChannelPromise promise) {
1182+
ChannelFuture future = super.writeRstStream(ctx, streamId, errorCode, promise);
1183+
// We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed
1184+
// frame.
1185+
boolean normalRst
1186+
= errorCode == Http2Error.NO_ERROR.code() || errorCode == Http2Error.CANCEL.code();
1187+
if (!normalRst) {
1188+
Http2Exception tooManyRstStream = rstStreamCounter.countRstStream();
1189+
if (tooManyRstStream != null) {
1190+
lifecycleManager.onError(ctx, true, tooManyRstStream);
1191+
ctx.close();
1192+
}
1193+
}
1194+
return future;
1195+
}
1196+
}
1197+
1198+
private static final class RstStreamCounter {
1199+
private final int maxRstCount;
1200+
private final long maxRstPeriodNanos;
1201+
private final Ticker ticker;
1202+
private int rstCount;
1203+
private long lastRstNanoTime;
1204+
1205+
RstStreamCounter(int maxRstCount, long maxRstPeriodNanos, Ticker ticker) {
1206+
checkArgument(maxRstCount >= 0, "maxRstCount must be non-negative: %s", maxRstCount);
1207+
this.maxRstCount = maxRstCount;
1208+
this.maxRstPeriodNanos = maxRstPeriodNanos;
1209+
this.ticker = checkNotNull(ticker, "ticker");
1210+
this.lastRstNanoTime = ticker.read();
1211+
}
1212+
1213+
/** Returns non-{@code null} when the connection should be killed by the caller. */
1214+
private Http2Exception countRstStream() {
1215+
if (maxRstCount == 0) {
1216+
return null;
1217+
}
1218+
long now = ticker.read();
1219+
if (now - lastRstNanoTime > maxRstPeriodNanos) {
1220+
lastRstNanoTime = now;
1221+
rstCount = 1;
1222+
} else {
1223+
rstCount++;
1224+
if (rstCount > maxRstCount) {
1225+
return new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") {
1226+
@SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses
1227+
@Override
1228+
public Throwable fillInStackTrace() {
1229+
// Avoid the CPU cycles, since the resets may be a CPU consumption attack
1230+
return this;
1231+
}
1232+
};
1233+
}
1234+
}
1235+
return null;
1236+
}
1237+
}
1238+
11831239
private static class ServerChannelLogger extends ChannelLogger {
11841240
private static final Logger log = Logger.getLogger(ChannelLogger.class.getName());
11851241

netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,8 @@ public void maxRstCount_exceedsLimit_fails() throws Exception {
13041304
}
13051305

13061306
private void rapidReset(int burstSize) throws Exception {
1307+
when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class)))
1308+
.thenAnswer((args) -> new TestServerStreamTracer());
13071309
Http2Headers headers = new DefaultHttp2Headers()
13081310
.method(HTTP_METHOD)
13091311
.set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8))
@@ -1323,6 +1325,48 @@ private void rapidReset(int burstSize) throws Exception {
13231325
}
13241326
}
13251327

1328+
@Test
1329+
public void maxRstCountSent_withinLimit_succeeds() throws Exception {
1330+
maxRstCount = 10;
1331+
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
1332+
manualSetUp();
1333+
madeYouReset(maxRstCount);
1334+
1335+
assertTrue(channel().isOpen());
1336+
}
1337+
1338+
@Test
1339+
public void maxRstCountSent_exceedsLimit_fails() throws Exception {
1340+
maxRstCount = 10;
1341+
maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100);
1342+
manualSetUp();
1343+
assertThrows(ClosedChannelException.class, () -> madeYouReset(maxRstCount + 1));
1344+
1345+
assertFalse(channel().isOpen());
1346+
}
1347+
1348+
private void madeYouReset(int burstSize) throws Exception {
1349+
when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class)))
1350+
.thenAnswer((args) -> new TestServerStreamTracer());
1351+
Http2Headers headers = new DefaultHttp2Headers()
1352+
.method(HTTP_METHOD)
1353+
.set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8))
1354+
.set(TE_HEADER, TE_TRAILERS)
1355+
.path(new AsciiString("/foo/bar"));
1356+
int streamId = 1;
1357+
long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize;
1358+
for (int period = 0; period < 3; period++) {
1359+
for (int i = 0; i < burstSize; i++) {
1360+
channelRead(headersFrame(streamId, headers));
1361+
channelRead(windowUpdate(streamId, 0));
1362+
streamId += 2;
1363+
fakeClock().forwardNanos(rpcTimeNanos);
1364+
}
1365+
while (channel().readOutbound() != null) {}
1366+
fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1);
1367+
}
1368+
}
1369+
13261370
private void createStream() throws Exception {
13271371
Http2Headers headers = new DefaultHttp2Headers()
13281372
.method(HTTP_METHOD)

0 commit comments

Comments
 (0)