60
60
import io .netty .channel .ChannelFutureListener ;
61
61
import io .netty .channel .ChannelHandlerContext ;
62
62
import io .netty .channel .ChannelPromise ;
63
+ import io .netty .handler .codec .http2 .DecoratingHttp2ConnectionEncoder ;
63
64
import io .netty .handler .codec .http2 .DecoratingHttp2FrameWriter ;
64
65
import io .netty .handler .codec .http2 .DefaultHttp2Connection ;
65
66
import io .netty .handler .codec .http2 .DefaultHttp2ConnectionDecoder ;
83
84
import io .netty .handler .codec .http2 .Http2Headers ;
84
85
import io .netty .handler .codec .http2 .Http2HeadersDecoder ;
85
86
import io .netty .handler .codec .http2 .Http2InboundFrameLogger ;
87
+ import io .netty .handler .codec .http2 .Http2LifecycleManager ;
86
88
import io .netty .handler .codec .http2 .Http2OutboundFrameLogger ;
87
89
import io .netty .handler .codec .http2 .Http2Settings ;
88
90
import io .netty .handler .codec .http2 .Http2Stream ;
@@ -125,13 +127,11 @@ class NettyServerHandler extends AbstractNettyHandler {
125
127
private final long keepAliveTimeoutInNanos ;
126
128
private final long maxConnectionAgeInNanos ;
127
129
private final long maxConnectionAgeGraceInNanos ;
128
- private final int maxRstCount ;
129
- private final long maxRstPeriodNanos ;
130
+ private final RstStreamCounter rstStreamCounter ;
130
131
private final List <? extends ServerStreamTracer .Factory > streamTracerFactories ;
131
132
private final TransportTracer transportTracer ;
132
133
private final KeepAliveEnforcer keepAliveEnforcer ;
133
134
private final Attributes eagAttributes ;
134
- private final Ticker ticker ;
135
135
/** Incomplete attributes produced by negotiator. */
136
136
private Attributes negotiationAttributes ;
137
137
private InternalChannelz .Security securityInfo ;
@@ -149,8 +149,6 @@ class NettyServerHandler extends AbstractNettyHandler {
149
149
private ScheduledFuture <?> maxConnectionAgeMonitor ;
150
150
@ CheckForNull
151
151
private GracefulShutdown gracefulShutdown ;
152
- private int rstCount ;
153
- private long lastRstNanoTime ;
154
152
155
153
static NettyServerHandler newHandler (
156
154
ServerTransportListener transportListener ,
@@ -251,13 +249,20 @@ static NettyServerHandler newHandler(
251
249
final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer (
252
250
permitKeepAliveWithoutCalls , permitKeepAliveTimeInNanos , TimeUnit .NANOSECONDS );
253
251
252
+ if (ticker == null ) {
253
+ ticker = Ticker .systemTicker ();
254
+ }
255
+
256
+ RstStreamCounter rstStreamCounter
257
+ = new RstStreamCounter (maxRstCount , maxRstPeriodNanos , ticker );
254
258
// Create the local flow controller configured to auto-refill the connection window.
255
259
connection .local ().flowController (
256
260
new DefaultHttp2LocalFlowController (connection , DEFAULT_WINDOW_UPDATE_RATIO , true ));
257
261
frameWriter = new WriteMonitoringFrameWriter (frameWriter , keepAliveEnforcer );
258
262
Http2ConnectionEncoder encoder =
259
263
new DefaultHttp2ConnectionEncoder (connection , frameWriter );
260
264
encoder = new Http2ControlFrameLimitEncoder (encoder , 10000 );
265
+ encoder = new Http2RstCounterEncoder (encoder , rstStreamCounter );
261
266
Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder (connection , encoder ,
262
267
frameReader );
263
268
@@ -266,10 +271,6 @@ static NettyServerHandler newHandler(
266
271
settings .maxConcurrentStreams (maxStreams );
267
272
settings .maxHeaderListSize (maxHeaderListSize );
268
273
269
- if (ticker == null ) {
270
- ticker = Ticker .systemTicker ();
271
- }
272
-
273
274
return new NettyServerHandler (
274
275
channelUnused ,
275
276
connection ,
@@ -286,8 +287,7 @@ static NettyServerHandler newHandler(
286
287
maxConnectionAgeInNanos , maxConnectionAgeGraceInNanos ,
287
288
keepAliveEnforcer ,
288
289
autoFlowControl ,
289
- maxRstCount ,
290
- maxRstPeriodNanos ,
290
+ rstStreamCounter ,
291
291
eagAttributes , ticker );
292
292
}
293
293
@@ -310,8 +310,7 @@ private NettyServerHandler(
310
310
long maxConnectionAgeGraceInNanos ,
311
311
final KeepAliveEnforcer keepAliveEnforcer ,
312
312
boolean autoFlowControl ,
313
- int maxRstCount ,
314
- long maxRstPeriodNanos ,
313
+ RstStreamCounter rstStreamCounter ,
315
314
Attributes eagAttributes ,
316
315
Ticker ticker ) {
317
316
super (
@@ -363,12 +362,9 @@ public void onStreamClosed(Http2Stream stream) {
363
362
this .maxConnectionAgeInNanos = maxConnectionAgeInNanos ;
364
363
this .maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos ;
365
364
this .keepAliveEnforcer = checkNotNull (keepAliveEnforcer , "keepAliveEnforcer" );
366
- this .maxRstCount = maxRstCount ;
367
- this .maxRstPeriodNanos = maxRstPeriodNanos ;
365
+ this .rstStreamCounter = rstStreamCounter ;
368
366
this .eagAttributes = checkNotNull (eagAttributes , "eagAttributes" );
369
- this .ticker = checkNotNull (ticker , "ticker" );
370
367
371
- this .lastRstNanoTime = ticker .read ();
372
368
streamKey = encoder .connection ().newKey ();
373
369
this .transportListener = checkNotNull (transportListener , "transportListener" );
374
370
this .streamTracerFactories = checkNotNull (streamTracerFactories , "streamTracerFactories" );
@@ -575,24 +571,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt
575
571
}
576
572
577
573
private void onRstStreamRead (int streamId , long errorCode ) throws Http2Exception {
578
- if (maxRstCount > 0 ) {
579
- long now = ticker .read ();
580
- if (now - lastRstNanoTime > maxRstPeriodNanos ) {
581
- lastRstNanoTime = now ;
582
- rstCount = 1 ;
583
- } else {
584
- rstCount ++;
585
- if (rstCount > maxRstCount ) {
586
- throw new Http2Exception (Http2Error .ENHANCE_YOUR_CALM , "too_many_rststreams" ) {
587
- @ SuppressWarnings ("UnsynchronizedOverridesSynchronized" ) // No memory accesses
588
- @ Override
589
- public Throwable fillInStackTrace () {
590
- // Avoid the CPU cycles, since the resets may be a CPU consumption attack
591
- return this ;
592
- }
593
- };
594
- }
595
- }
574
+ Http2Exception tooManyRstStream = rstStreamCounter .countRstStream ();
575
+ if (tooManyRstStream != null ) {
576
+ throw tooManyRstStream ;
596
577
}
597
578
598
579
try {
@@ -1180,6 +1161,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2
1180
1161
}
1181
1162
}
1182
1163
1164
+ private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder {
1165
+ private final RstStreamCounter rstStreamCounter ;
1166
+ private Http2LifecycleManager lifecycleManager ;
1167
+
1168
+ Http2RstCounterEncoder (Http2ConnectionEncoder encoder , RstStreamCounter rstStreamCounter ) {
1169
+ super (encoder );
1170
+ this .rstStreamCounter = rstStreamCounter ;
1171
+ }
1172
+
1173
+ @ Override
1174
+ public void lifecycleManager (Http2LifecycleManager lifecycleManager ) {
1175
+ this .lifecycleManager = lifecycleManager ;
1176
+ super .lifecycleManager (lifecycleManager );
1177
+ }
1178
+
1179
+ @ Override
1180
+ public ChannelFuture writeRstStream (
1181
+ ChannelHandlerContext ctx , int streamId , long errorCode , ChannelPromise promise ) {
1182
+ ChannelFuture future = super .writeRstStream (ctx , streamId , errorCode , promise );
1183
+ // We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed
1184
+ // frame.
1185
+ boolean normalRst
1186
+ = errorCode == Http2Error .NO_ERROR .code () || errorCode == Http2Error .CANCEL .code ();
1187
+ if (!normalRst ) {
1188
+ Http2Exception tooManyRstStream = rstStreamCounter .countRstStream ();
1189
+ if (tooManyRstStream != null ) {
1190
+ lifecycleManager .onError (ctx , true , tooManyRstStream );
1191
+ ctx .close ();
1192
+ }
1193
+ }
1194
+ return future ;
1195
+ }
1196
+ }
1197
+
1198
+ private static final class RstStreamCounter {
1199
+ private final int maxRstCount ;
1200
+ private final long maxRstPeriodNanos ;
1201
+ private final Ticker ticker ;
1202
+ private int rstCount ;
1203
+ private long lastRstNanoTime ;
1204
+
1205
+ RstStreamCounter (int maxRstCount , long maxRstPeriodNanos , Ticker ticker ) {
1206
+ checkArgument (maxRstCount >= 0 , "maxRstCount must be non-negative: %s" , maxRstCount );
1207
+ this .maxRstCount = maxRstCount ;
1208
+ this .maxRstPeriodNanos = maxRstPeriodNanos ;
1209
+ this .ticker = checkNotNull (ticker , "ticker" );
1210
+ this .lastRstNanoTime = ticker .read ();
1211
+ }
1212
+
1213
+ /** Returns non-{@code null} when the connection should be killed by the caller. */
1214
+ private Http2Exception countRstStream () {
1215
+ if (maxRstCount == 0 ) {
1216
+ return null ;
1217
+ }
1218
+ long now = ticker .read ();
1219
+ if (now - lastRstNanoTime > maxRstPeriodNanos ) {
1220
+ lastRstNanoTime = now ;
1221
+ rstCount = 1 ;
1222
+ } else {
1223
+ rstCount ++;
1224
+ if (rstCount > maxRstCount ) {
1225
+ return new Http2Exception (Http2Error .ENHANCE_YOUR_CALM , "too_many_rststreams" ) {
1226
+ @ SuppressWarnings ("UnsynchronizedOverridesSynchronized" ) // No memory accesses
1227
+ @ Override
1228
+ public Throwable fillInStackTrace () {
1229
+ // Avoid the CPU cycles, since the resets may be a CPU consumption attack
1230
+ return this ;
1231
+ }
1232
+ };
1233
+ }
1234
+ }
1235
+ return null ;
1236
+ }
1237
+ }
1238
+
1183
1239
private static class ServerChannelLogger extends ChannelLogger {
1184
1240
private static final Logger log = Logger .getLogger (ChannelLogger .class .getName ());
1185
1241
0 commit comments