diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index 11cd177a840..50de8c7002f 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -321,11 +321,12 @@ public void transportHeadersReceived(List
headers, boolean endOfStream) * Must be called with holding the transport lock. */ @GuardedBy("lock") - public void transportDataReceived(okio.Buffer frame, boolean endOfStream) { + public void transportDataReceived(okio.Buffer frame, boolean endOfStream, int paddingLen) { // We only support 16 KiB frames, and the max permitted in HTTP/2 is 16 MiB. This is verified // in OkHttp's Http2 deframer. In addition, this code is after the data has been read. int length = (int) frame.size(); - window -= length; + window -= length + paddingLen; + processedWindow -= paddingLen; if (window < 0) { frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR); transport.finishStream( diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 6eaaf832a6b..ea3bf77e990 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -1140,7 +1140,8 @@ public void run() { */ @SuppressWarnings("GuardedBy") @Override - public void data(boolean inFinished, int streamId, BufferedSource in, int length) + public void data(boolean inFinished, int streamId, BufferedSource in, int length, + int paddedLength) throws IOException { logger.logData(OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished); @@ -1166,12 +1167,12 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length synchronized (lock) { // TODO(b/145386688): This access should be guarded by 'stream.transportState().lock'; // instead found: 'OkHttpClientTransport.this.lock' - stream.transportState().transportDataReceived(buf, inFinished); + stream.transportState().transportDataReceived(buf, inFinished, paddedLength - length); } } // connection window update - connectionUnacknowledgedBytesRead += length; + connectionUnacknowledgedBytesRead += paddedLength; if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) { synchronized (lock) { frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java index 85ed916095b..51275429c93 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java @@ -208,13 +208,15 @@ public void runOnTransportThread(final Runnable r) { * Must be called with holding the transport lock. */ @Override - public void inboundDataReceived(okio.Buffer frame, int windowConsumed, boolean endOfStream) { + public void inboundDataReceived(okio.Buffer frame, int dataLength, int paddingLength, + boolean endOfStream) { synchronized (lock) { PerfMark.event("OkHttpServerTransport$FrameHandler.data", tag); if (endOfStream) { this.receivedEndOfStream = true; } - window -= windowConsumed; + window -= dataLength + paddingLength; + processedWindow -= paddingLength; super.inboundDataReceived(new OkHttpReadableBuffer(frame), endOfStream); } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java index dae59f37ad2..8fb74d3f1b5 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -248,8 +248,8 @@ public void data(boolean outFinished, int streamId, Buffer source, int byteCount TimeUnit.NANOSECONDS); } - transportExecutor.execute( - new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false))); + transportExecutor.execute(new FrameHandler( + variant.newReader(Okio.buffer(Okio.source(socket)), false))); } catch (Error | IOException | RuntimeException ex) { synchronized (lock) { if (!handshakeShutdown) { @@ -708,7 +708,7 @@ public void headers(boolean outFinished, return; } // Ignore the trailers, but still half-close the stream - stream.inboundDataReceived(new Buffer(), 0, true); + stream.inboundDataReceived(new Buffer(), 0, 0, true); return; } } else { @@ -799,7 +799,7 @@ public void headers(boolean outFinished, listener.streamCreated(streamForApp, method, metadata); stream.onStreamAllocated(); if (inFinished) { - stream.inboundDataReceived(new Buffer(), 0, inFinished); + stream.inboundDataReceived(new Buffer(), 0, 0, inFinished); } } } @@ -819,7 +819,8 @@ private int headerBlockSize(List
headerBlock) { * Handle an HTTP2 DATA frame. */ @Override - public void data(boolean inFinished, int streamId, BufferedSource in, int length) + public void data(boolean inFinished, int streamId, BufferedSource in, int length, + int paddedLength) throws IOException { frameLogger.logData( OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished); @@ -853,7 +854,7 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length "Received DATA for half-closed (remote) stream. RFC7540 section 5.1"); return; } - if (stream.inboundWindowAvailable() < length) { + if (stream.inboundWindowAvailable() < paddedLength) { in.skip(length); streamError(streamId, ErrorCode.FLOW_CONTROL_ERROR, "Received DATA size exceeded window size. RFC7540 section 6.9"); @@ -861,11 +862,11 @@ public void data(boolean inFinished, int streamId, BufferedSource in, int length } Buffer buf = new Buffer(); buf.write(in.getBuffer(), length); - stream.inboundDataReceived(buf, length, inFinished); + stream.inboundDataReceived(buf, length, paddedLength - length, inFinished); } // connection window update - connectionUnacknowledgedBytesRead += length; + connectionUnacknowledgedBytesRead += paddedLength; if (connectionUnacknowledgedBytesRead >= config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO) { synchronized (lock) { @@ -1064,7 +1065,7 @@ private void respondWithHttpError( } streams.put(streamId, stream); if (inFinished) { - stream.inboundDataReceived(new Buffer(), 0, true); + stream.inboundDataReceived(new Buffer(), 0, 0, true); } frameWriter.headers(streamId, headers); outboundFlow.data( @@ -1122,7 +1123,7 @@ public void onPingTimeout() { interface StreamState { /** Must be holding 'lock' when calling. */ - void inboundDataReceived(Buffer frame, int windowConsumed, boolean endOfStream); + void inboundDataReceived(Buffer frame, int dataLength, int paddingLength, boolean endOfStream); /** Must be holding 'lock' when calling. */ boolean hasReceivedEndOfStream(); @@ -1159,12 +1160,12 @@ static class Http2ErrorStreamState implements StreamState, OutboundFlowControlle @Override public void onSentBytes(int frameBytes) {} @Override public void inboundDataReceived( - Buffer frame, int windowConsumed, boolean endOfStream) { + Buffer frame, int dataLength, int paddingLength, boolean endOfStream) { synchronized (lock) { if (endOfStream) { receivedEndOfStream = true; } - window -= windowConsumed; + window -= dataLength + paddingLength; try { frame.skip(frame.size()); // Recycle segments } catch (IOException ex) { diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 644ba27b50f..7347399bfe5 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -291,14 +291,16 @@ public void close() throws SecurityException { final String message = "Hello Client"; Buffer buffer = createMessageFrame(message); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); assertThat(logs).hasSize(1); log = logs.remove(0); assertThat(log.getMessage()).startsWith(Direction.INBOUND + " DATA: streamId=" + 3); assertThat(log.getLevel()).isEqualTo(Level.FINE); // At most 64 bytes of data frame will be logged. - frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000); + frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), + 1000, 1000); assertThat(logs).hasSize(1); log = logs.remove(0); String data = log.getMessage(); @@ -377,7 +379,8 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { // Receive the message. final String message = "Hello Client"; Buffer buffer = createMessageFrame(message); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); listener.waitUntilStreamClosed(); assertEquals(Code.RESOURCE_EXHAUSTED, listener.status.getCode()); @@ -500,7 +503,8 @@ public void readMessages() throws Exception { assertNotNull(listener.headers); for (int i = 0; i < numMessages; i++) { Buffer buffer = createMessageFrame(message + i); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); } frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); listener.waitUntilStreamClosed(); @@ -529,7 +533,8 @@ public void receivedHeadersForInvalidStreamShouldKillConnection() throws Excepti @Test public void receivedDataForInvalidStreamShouldKillConnection() throws Exception { initTransport(); - frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000); + frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), + 1000, 1000); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); verify(transportListener).transportShutdown(isA(Status.class)); @@ -551,7 +556,8 @@ public void invalidInboundHeadersCancelStream() throws Exception { HeadersMode.HTTP_20_HEADERS); // Now wait to receive 1000 bytes of data so we can have a better error message before // cancelling the streaam. - frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000); + frameHandler().data(false, 3, + createMessageFrame(new String(new char[1000])), 1000, 1000); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); assertNull(listener.headers); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); @@ -622,7 +628,8 @@ public void receiveResetNoError() throws Exception { assertContainStream(3); frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); Buffer buffer = createMessageFrame("a message"); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); frameHandler().rstStream(3, ErrorCode.NO_ERROR); stream.request(1); @@ -762,15 +769,18 @@ public void windowUpdate() throws Exception { int messageLength = INITIAL_WINDOW_SIZE / 4; byte[] fakeMessage = new byte[messageLength]; + int paddingLength = 2; // Stream 1 receives a message - Buffer buffer = createMessageFrame(fakeMessage); + Buffer buffer = createMessageFrame(fakeMessage, paddingLength); int messageFrameLength = (int) buffer.size(); - frameHandler().data(false, 3, buffer, messageFrameLength); + frameHandler().data(false, 3, buffer, messageFrameLength - paddingLength, + messageFrameLength); // Stream 2 receives a message - buffer = createMessageFrame(fakeMessage); - frameHandler().data(false, 5, buffer, messageFrameLength); + buffer = createMessageFrame(fakeMessage, paddingLength); + frameHandler().data(false, 5, buffer, messageFrameLength - paddingLength, + messageFrameLength); verify(frameWriter, timeout(TIME_OUT_MS)) .windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); @@ -778,17 +788,18 @@ public void windowUpdate() throws Exception { // Stream 1 receives another message buffer = createMessageFrame(fakeMessage); - frameHandler().data(false, 3, buffer, messageFrameLength); + messageFrameLength = (int) buffer.size(); + frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength); verify(frameWriter, timeout(TIME_OUT_MS)) - .windowUpdate(eq(3), eq((long) 2 * messageFrameLength)); + .windowUpdate(eq(3), eq((long) 2 * messageFrameLength + paddingLength)); // Stream 2 receives another message buffer = createMessageFrame(fakeMessage); - frameHandler().data(false, 5, buffer, messageFrameLength); + frameHandler().data(false, 5, buffer, messageFrameLength, messageFrameLength); verify(frameWriter, timeout(TIME_OUT_MS)) - .windowUpdate(eq(5), eq((long) 2 * messageFrameLength)); + .windowUpdate(eq(5), eq((long) 2 * messageFrameLength + paddingLength)); verify(frameWriter, timeout(TIME_OUT_MS)) .windowUpdate(eq(0), eq((long) 2 * messageFrameLength)); @@ -819,7 +830,8 @@ public void windowUpdateWithInboundFlowControl() throws Exception { frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); Buffer buffer = createMessageFrame(fakeMessage); long messageFrameLength = buffer.size(); - frameHandler().data(false, 3, buffer, (int) messageFrameLength); + frameHandler().data(false, 3, buffer, (int) messageFrameLength, + (int) messageFrameLength); ArgumentCaptor idCaptor = ArgumentCaptor.forClass(Integer.class); verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate( idCaptor.capture(), eq(messageFrameLength)); @@ -1123,7 +1135,8 @@ public void receiveGoAway() throws Exception { frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); final String receivedMessage = "No, you are fine."; Buffer buffer = createMessageFrame(receivedMessage); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); listener1.waitUntilStreamClosed(); assertEquals(1, listener1.messages.size()); @@ -1154,12 +1167,12 @@ public void streamIdExhausted() throws Exception { assertNotNull(listener.headers); String message = "hello"; Buffer buffer = createMessageFrame(message); - frameHandler().data(false, startId, buffer, (int) buffer.size()); + frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size()); getStream(startId).cancel(Status.CANCELLED); // Receives the second message after be cancelled. buffer = createMessageFrame(message); - frameHandler().data(false, startId, buffer, (int) buffer.size()); + frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size()); listener.waitUntilStreamClosed(); // Should only have the first message delivered. @@ -1329,7 +1342,7 @@ public void receivingWindowExceeded() throws Exception { byte[] fakeMessage = new byte[messageLength]; Buffer buffer = createMessageFrame(fakeMessage); int messageFrameLength = (int) buffer.size(); - frameHandler().data(false, 3, buffer, messageFrameLength); + frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength); listener.waitUntilStreamClosed(); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); @@ -1392,7 +1405,8 @@ public void receiveDataWithoutHeader() throws Exception { stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); // Trigger the failure by a trailer. frameHandler().headers( @@ -1414,11 +1428,13 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); // Trigger the failure by a data frame. buffer = createMessageFrame(new byte[1]); - frameHandler().data(true, 3, buffer, (int) buffer.size()); + frameHandler().data(true, 3, buffer, (int) buffer.size(), + (int) buffer.size()); listener.waitUntilStreamClosed(); assertEquals(Status.INTERNAL.getCode(), listener.status.getCode()); @@ -1436,7 +1452,8 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1000]); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); // Once we receive enough detail, we cancel the stream. so we should have sent cancel. verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -1459,7 +1476,8 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception Buffer buffer = createMessageFrame( new byte[INITIAL_WINDOW_SIZE / 2 + 1]); - frameHandler().data(false, 3, buffer, (int) buffer.size()); + frameHandler().data(false, 3, buffer, (int) buffer.size(), + (int) buffer.size()); // Should still update the connection window even stream 3 is gone. verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(0, HEADER_LENGTH + INITIAL_WINDOW_SIZE / 2 + 1); @@ -1467,7 +1485,8 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception new byte[INITIAL_WINDOW_SIZE / 2 + 1]); // This should kill the connection, since we never created stream 5. - frameHandler().data(false, 5, buffer, (int) buffer.size()); + frameHandler().data(false, 5, buffer, (int) buffer.size(), + (int) buffer.size()); verify(frameWriter, timeout(TIME_OUT_MS)) .goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class)); verify(transportListener).transportShutdown(isA(Status.class)); @@ -2114,10 +2133,15 @@ private static Buffer createMessageFrame(String message) { } private static Buffer createMessageFrame(byte[] message) { + return createMessageFrame(message,0); + } + + private static Buffer createMessageFrame(byte[] message, int paddingLength) { Buffer buffer = new Buffer(); buffer.writeByte(0 /* UNCOMPRESSED */); buffer.writeInt(message.length); buffer.write(message); + buffer.write(new byte[paddingLength]); return buffer; } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java index 455908816a8..5ed8514b85b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -60,6 +60,7 @@ import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Arrays; import java.util.Deque; import java.util.List; @@ -70,6 +71,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import okio.Buffer; +import okio.BufferedSink; import okio.BufferedSource; import okio.ByteString; import okio.Okio; @@ -90,15 +92,19 @@ public class OkHttpServerTransportTest { private static final int TIME_OUT_MS = 2000; private static final int INITIAL_WINDOW_SIZE = 65535; private static final long MAX_CONNECTION_IDLE = TimeUnit.SECONDS.toNanos(1); - + private static final byte FLAG_NONE = 0x0; + private static final byte FLAG_PADDED = 0x8; + private static final byte FLAG_END_STREAM = 0x1; + private static final byte TYPE_DATA = 0x0; private MockServerTransportListener mockTransportListener = new MockServerTransportListener(); private ServerTransportListener transportListener = mock(ServerTransportListener.class, delegatesTo(mockTransportListener)); private OkHttpServerTransport serverTransport; private final ExecutorService threadPool = Executors.newCachedThreadPool(); private final SocketPair socketPair = SocketPair.create(threadPool); - private final FrameWriter clientFrameWriter - = new Http2().newWriter(Okio.buffer(Okio.sink(socketPair.getClientOutputStream())), true); + private final BufferedSink clientWriterSink = Okio.buffer( + Okio.sink(socketPair.getClientOutputStream())); + private final FrameWriter clientFrameWriter = new Http2().newWriter(clientWriterSink, true); private final FrameReader clientFrameReader = new Http2().newReader(Okio.buffer(Okio.source(socketPair.getClientInputStream())), true); private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class); @@ -135,7 +141,8 @@ public void setUp() throws Exception { Buffer buf = new Buffer(); buf.write(in.getBuffer(), length); clientDataFrames.data(outDone, streamId, buf); - })).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt()); + })).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt(), + anyInt()); } @After @@ -379,7 +386,8 @@ public void basicRpc_succeeds() throws Exception { Buffer responseMessageFrame = createMessageFrame("Howdy client"); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead) - .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size())); + .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()), + eq((int) responseMessageFrame.size())); verify(clientDataFrames).data(false, 1, responseMessageFrame); List
responseTrailers = Arrays.asList( @@ -440,7 +448,8 @@ public void activeRpc_delaysShutdownTermination() throws Exception { Buffer responseMessageFrame = createMessageFrame("Howdy client"); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead) - .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size())); + .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()), + eq((int) responseMessageFrame.size())); verify(clientDataFrames).data(false, 1, responseMessageFrame); pingPong(); assertThat(serverTransport.getActiveStreams().length).isEqualTo(1); @@ -975,7 +984,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception { Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(0, 1)); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead).data( - eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size())); + eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()), + eq((int) responseDataFrame.size())); verify(clientDataFrames).data(false, 1, responseDataFrame); clientFrameWriter.windowUpdate(1, 1000); @@ -984,7 +994,8 @@ public void httpErrorsAdhereToFlowControl() throws Exception { responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(1)); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead).data( - eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size())); + eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()), + eq((int) responseDataFrame.size())); verify(clientDataFrames).data(true, 1, responseDataFrame); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); @@ -993,6 +1004,71 @@ public void httpErrorsAdhereToFlowControl() throws Exception { shutdownAndTerminate(/*lastStreamId=*/ 1); } + @Test + public void windowUpdate() throws Exception { + serverBuilder.flowControlWindow(100); + initTransport(); + handshake(); + + List
headers = Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything")); + clientFrameWriter.headers(1, new ArrayList<>(headers)); + clientFrameWriter.headers(3, new ArrayList<>(headers)); + String message = "Hello Server Pad Me!"; // length = 20, add buffer length = 5 + writeDataDirectly(clientWriterSink, FLAG_NONE, 1, message, 0); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + MockStreamListener streamListener2 = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + assertThat(streamListener.method).isEqualTo("com.example/SimpleService.doit"); + assertThat(streamListener.headers.get( + Metadata.Key.of("Some-Metadata", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("this could be anything"); + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello Server Pad Me!"); + streamListener.stream.writeHeaders(metadata("User-Data", "best data")); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + List
responseHeaders = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER, + new Header("user-data", "best data")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + writeDataDirectly(clientWriterSink, FLAG_PADDED, 1, message, 10); + writeDataDirectly(clientWriterSink, FLAG_PADDED | FLAG_END_STREAM, 3, message, 40); + clientFrameWriter.flush(); + + int expectedConsumed = message.length() + 5; + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).windowUpdate(0, expectedConsumed * 2 + 10); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).windowUpdate(0, expectedConsumed + 40); + streamListener.stream.request(2); + streamListener2.stream.request(1); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).windowUpdate(1, expectedConsumed * 2 + 10); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).windowUpdate(3, expectedConsumed + 40); + + writeDataDirectly(clientWriterSink, FLAG_PADDED | FLAG_END_STREAM, 1, message, 100); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(eq(1), eq(ErrorCode.FLOW_CONTROL_ERROR)); + clientFrameWriter.rstStream(3, ErrorCode.CANCEL); + pingPong(); + shutdownAndTerminate(/*lastStreamId=*/ 3); + } + @Test public void dataForStream0_failsWithGoAway() throws Exception { initTransport(); @@ -1223,6 +1299,32 @@ private static Buffer createMessageFrame(String stringMessage) { return buffer; } + private void writeDataDirectly(BufferedSink sink, int flag, int streamId, String message, + int paddingLength) throws IOException { + Buffer buffer = createMessageFrame(message); + int bufferLengthWithPadding = (int) buffer.size(); + if ((flag & FLAG_PADDED) != 0) { + bufferLengthWithPadding += paddingLength; + } + writeLength(sink, bufferLengthWithPadding); + sink.writeByte(TYPE_DATA); + sink.writeByte(flag & 0xff); + sink.writeInt(streamId & 0x7fffffff); + if ((flag & FLAG_PADDED) != 0) { + sink.writeByte((short)(paddingLength - 1)); + char[] value = new char[paddingLength - 1]; + Arrays.fill(value, '!'); + buffer.write(new String(value).getBytes(UTF_8)); + } + sink.write(buffer, buffer.size()); + } + + private void writeLength(BufferedSink sink, int length) throws IOException { + sink.writeByte((length >>> 16 ) & 0xff); + sink.writeByte((length >>> 8 ) & 0xff); + sink.writeByte(length & 0xff); + } + private Metadata metadata(String... keysAndValues) { Metadata metadata = new Metadata(); assertThat(keysAndValues.length % 2).isEqualTo(0); @@ -1279,7 +1381,8 @@ private void verifyHttpError( Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); verify(clientFramesRead).data( - eq(true), eq(streamId), any(BufferedSource.class), eq((int) responseDataFrame.size())); + eq(true), eq(streamId), any(BufferedSource.class), + eq((int) responseDataFrame.size()), eq((int) responseDataFrame.size())); verify(clientDataFrames).data(true, streamId, responseDataFrame); assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/FrameReader.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/FrameReader.java index 585ccb15355..490673eff6a 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/FrameReader.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/FrameReader.java @@ -32,7 +32,7 @@ public interface FrameReader extends Closeable { boolean nextFrame(Handler handler) throws IOException; interface Handler { - void data(boolean inFinished, int streamId, BufferedSource source, int length) + void data(boolean inFinished, int streamId, BufferedSource source, int length, int paddedLength) throws IOException; /** diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java index 3a8c41c6285..0eb49b9f076 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java @@ -220,7 +220,7 @@ private List
readHeaderBlock(int length, short padding, byte flags, int return hpackReader.getAndResetHeaderList(); } - private void readData(Handler handler, int length, byte flags, int streamId) + private void readData(Handler handler, int paddedLength, byte flags, int streamId) throws IOException { // TODO: checkState open or half-closed (local) or raise STREAM_CLOSED boolean inFinished = (flags & FLAG_END_STREAM) != 0; @@ -230,10 +230,10 @@ private void readData(Handler handler, int length, byte flags, int streamId) } short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0; - length = lengthWithoutPadding(length, flags, padding); + int length = lengthWithoutPadding(paddedLength, flags, padding); // FIXME: pass padding length to handler because it should be included for flow control - handler.data(inFinished, streamId, source, length); + handler.data(inFinished, streamId, source, length, paddedLength); source.skip(padding); } diff --git a/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/Http2Test.java b/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/Http2Test.java new file mode 100644 index 00000000000..5631a18515d --- /dev/null +++ b/okhttp/third_party/okhttp/test/java/io/grpc/okhttp/internal/framed/Http2Test.java @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2023 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp.internal.framed; + +import static io.grpc.okhttp.internal.framed.Http2.FLAG_NONE; +import static io.grpc.okhttp.internal.framed.Http2.FLAG_PADDED; +import static io.grpc.okhttp.internal.framed.Http2.TYPE_DATA; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import okio.Buffer; +import okio.BufferedSink; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class Http2Test { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + private FrameReader http2FrameReader; + @Mock + private FrameReader.Handler mockHandler; + private final int STREAM_ID = 6; + + @Test + public void dataFrameNoPadding() throws IOException { + Buffer bufferIn = createData(FLAG_NONE, 3239, 0 ); + http2FrameReader = new Http2.Reader(bufferIn, 100, true); + http2FrameReader.nextFrame(mockHandler); + + verify(mockHandler).data(eq(false), eq(STREAM_ID), eq(bufferIn), eq(3239), eq(3239)); + assertEquals(3239, bufferIn.size()); + } + + @Test + public void dataFrameOneLengthPadding() throws IOException { + Buffer bufferIn = createData(FLAG_PADDED, 1876, 0); + http2FrameReader = new Http2.Reader(bufferIn, 100, true); + http2FrameReader.nextFrame(mockHandler); + + verify(mockHandler).data(eq(false), eq(STREAM_ID), eq(bufferIn), eq(1875), eq(1876)); + assertEquals(1876, bufferIn.size()); + } + + @Test + public void dataFramePadding() throws IOException { + Buffer bufferIn = createData(FLAG_PADDED, 2037, 125); + http2FrameReader = new Http2.Reader(bufferIn, 100, true); + http2FrameReader.nextFrame(mockHandler); + + verify(mockHandler).data(eq(false), eq(STREAM_ID), eq(bufferIn), eq(2037 - 126), eq(2037)); + assertEquals(2037 - 125, bufferIn.size()); + } + + private Buffer createData(int flag, int length, int paddingLength) throws IOException { + Buffer sink = new Buffer(); + writeLength(sink, length); + sink.writeByte(TYPE_DATA); + sink.writeByte(flag); + sink.writeInt(STREAM_ID); + if ((flag & FLAG_PADDED) != 0) { + sink.writeByte((short)paddingLength); + } + char[] value = new char[length]; + Arrays.fill(value, '!'); + sink.write(new String(value).getBytes(StandardCharsets.UTF_8)); + return sink; + } + + private void writeLength(BufferedSink sink, int length) throws IOException { + sink.writeByte((length >>> 16 ) & 0xff); + sink.writeByte((length >>> 8 ) & 0xff); + sink.writeByte(length & 0xff); + } +}