diff --git a/collector/body.go b/collector/body.go index bced122..6df54d7 100644 --- a/collector/body.go +++ b/collector/body.go @@ -1,6 +1,7 @@ package collector import ( + "bytes" "errors" "io" "sync" @@ -37,6 +38,64 @@ func NewBody(rc io.ReadCloser, limit int) *Body { return b } +// PreReadBody creates a new Body wrapper that immediately pre-reads data from the body. +// This ensures body content is captured even if the underlying connection is closed early. +// It returns a Body with an io.MultiReader that combines the pre-read buffer with the original reader. +func PreReadBody(rc io.ReadCloser, limit int) *Body { + if rc == nil { + return NewBody(rc, limit) + } + + b := &Body{} + + var preReadBuffer = new(bytes.Buffer) + + // Pre-read up to limit bytes into our capture buffer + n, err := io.CopyN(preReadBuffer, rc, int64(limit)+1) // +1 to check for truncation + + truncated := n > int64(limit) + + if err == io.EOF { + // We've read everything (body was smaller than limit). + b.consumedOriginal = true + b.isFullyCaptured = !truncated + } + + multiReader := io.MultiReader(preReadBuffer, rc) + + // Wrap in a readCloser to maintain the Close capability + b.reader = &preReadBodyWrapper{ + Reader: multiReader, + closer: rc, + } + + // Set up the buffer with pre-read data but only up to the limit + preReadBytes := preReadBuffer.Bytes() + if len(preReadBytes) > limit { + preReadBytes = preReadBytes[:limit] + } + b.buffer = &LimitedBuffer{ + Buffer: bytes.NewBuffer(preReadBytes), + limit: limit, + truncated: truncated, + } + + return b +} + +// preReadBodyWrapper wraps an io.Reader with Close functionality +type preReadBodyWrapper struct { + io.Reader + closer io.Closer +} + +func (w *preReadBodyWrapper) Close() error { + if w.closer != nil { + return w.closer.Close() + } + return nil +} + func (b *Body) Read(p []byte) (n int, err error) { b.mu.Lock() defer b.mu.Unlock() @@ -52,13 +111,18 @@ func (b *Body) Read(p []byte) (n int, err error) { // Read from the original reader n, err = b.reader.Read(p) + // Only write to buffer if it's not a preReadBodyWrapper + // (preReadBodyWrapper means we already captured the data in PreReadBody) if n > 0 { - b.buffer.Write(p[:n]) + if _, isPreRead := b.reader.(*preReadBodyWrapper); !isPreRead { + _, _ = b.buffer.Write(p[:n]) + } } // If EOF, mark as fully consumed if err == io.EOF { b.consumedOriginal = true + b.isFullyCaptured = !b.buffer.IsTruncated() // Remove original body b.reader = nil @@ -68,7 +132,6 @@ func (b *Body) Read(p []byte) (n int, err error) { } // Close closes the original body and finalizes the buffer. -// This will attempt to read any unread data from the original body up to the maximum size limit. func (b *Body) Close() error { b.mu.Lock() defer b.mu.Unlock() @@ -81,39 +144,25 @@ func (b *Body) Close() error { return nil } - // Mark as closed before capturing remaining data to avoid potential recursive calls + // Mark as closed b.closed = true - // Check state to determine if we need to read more data - fullyConsumed := b.consumedOriginal - - // If the body wasn't fully read, read the rest of it into our buffer - if !fullyConsumed { - // Create a buffer for reading - buf := make([]byte, 32*1024) // 32KB chunks - - // Try to read more data - for { - var n int - var readErr error - n, readErr = b.reader.Read(buf) - - if n > 0 { - b.buffer.Write(buf[:n]) - } + // For PreReadBody cases (identified by preReadBodyWrapper), + // the data is already captured, just close + if _, isPreRead := b.reader.(*preReadBodyWrapper); isPreRead { + return b.reader.Close() + } - if readErr != nil { - // We've read all we can - break - } - } + // For legacy NewBody usage (when not using PreReadBody), + // we still need to try to read remaining data + if !b.consumedOriginal { + _, _ = io.Copy(b.buffer, b.reader) } - // Now close the original reader - its implementation should handle any cleanup + // Close the original reader err := b.reader.Close() if !b.buffer.IsTruncated() { - // Mark as fully captured b.isFullyCaptured = true } diff --git a/collector/body_test.go b/collector/body_test.go index 1d560ff..8b579c5 100644 --- a/collector/body_test.go +++ b/collector/body_test.go @@ -65,3 +65,190 @@ func TestBody_ReadAfterClose(t *testing.T) { assert.Error(t, err) assert.Equal(t, collector.ErrBodyClosed, err) } + +// Test PreReadBody with small body that handler doesn't read +func TestPreReadBody_SmallBodyUnread(t *testing.T) { + data := "small test data" + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 1024) // limit > data size + + // Close without reading + err := body.Close() + require.NoError(t, err) + + // Captured data should be available + assert.Equal(t, data, body.String()) + assert.Equal(t, []byte(data), body.Bytes()) + assert.Equal(t, int64(len(data)), body.Size()) + assert.True(t, body.IsFullyCaptured()) + assert.False(t, body.IsTruncated()) +} + +// Test PreReadBody with small body that handler fully reads +func TestPreReadBody_SmallBodyRead(t *testing.T) { + data := "small test data for reading" + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 1024) // limit > data size + + // Handler reads the entire body + handlerData, err := io.ReadAll(body) + require.NoError(t, err) + assert.Equal(t, data, string(handlerData)) + + // Close the body + err = body.Close() + require.NoError(t, err) + + // Captured data should STILL be available after reading + closing + assert.Equal(t, data, body.String()) + assert.Equal(t, []byte(data), body.Bytes()) + assert.Equal(t, int64(len(data)), body.Size()) + assert.True(t, body.IsFullyCaptured()) + assert.False(t, body.IsTruncated()) +} + +// Test PreReadBody with large body that handler doesn't read +func TestPreReadBody_LargeBodyUnread(t *testing.T) { + data := strings.Repeat("x", 2000) // Large data + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 100) // limit < data size + + // Close without reading + err := body.Close() + require.NoError(t, err) + + // Only first 100 bytes should be captured + expectedCaptured := data[:100] + assert.Equal(t, expectedCaptured, body.String()) + assert.Equal(t, []byte(expectedCaptured), body.Bytes()) + assert.Equal(t, int64(100), body.Size()) + assert.False(t, body.IsFullyCaptured()) + assert.True(t, body.IsTruncated()) +} + +// Test PreReadBody with large body that handler fully reads - CRITICAL TEST +func TestPreReadBody_LargeBodyRead(t *testing.T) { + data := strings.Repeat("y", 2000) // Large data + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 100) // limit < data size + + // Handler reads the entire body (pre-read + remaining) + handlerData, err := io.ReadAll(body) + require.NoError(t, err) + assert.Equal(t, data, string(handlerData)) // Handler should get full data + + // Close the body + err = body.Close() + require.NoError(t, err) + + // Captured portion should STILL be available after reading + expectedCaptured := data[:100] + assert.Equal(t, expectedCaptured, body.String()) + assert.Equal(t, []byte(expectedCaptured), body.Bytes()) + assert.Equal(t, int64(100), body.Size()) + assert.False(t, body.IsFullyCaptured()) // Only partial capture + assert.True(t, body.IsTruncated()) +} + +// Test closing behavior - close without reading small body +func TestPreReadBody_CloseWithoutReading_SmallBody(t *testing.T) { + data := "test data" + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 1024) + + // Close immediately without any reading + err := body.Close() + require.NoError(t, err) + + // Captured data should be preserved + assert.Equal(t, data, body.String()) + assert.True(t, body.IsFullyCaptured()) +} + +// Test closing behavior - close without reading large body +func TestPreReadBody_CloseWithoutReading_LargeBody(t *testing.T) { + data := strings.Repeat("z", 500) + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 100) + + // Close immediately without any reading + err := body.Close() + require.NoError(t, err) + + // Captured portion should be preserved + expectedCaptured := data[:100] + assert.Equal(t, expectedCaptured, body.String()) + assert.True(t, body.IsTruncated()) +} + +// Test closing behavior - close after partial reading +func TestPreReadBody_CloseAfterPartialReading(t *testing.T) { + data := strings.Repeat("a", 500) + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 100) + + // Handler reads only part of the body + buf := make([]byte, 50) + n, err := body.Read(buf) + require.NoError(t, err) + assert.Equal(t, 50, n) + + // Close the body + err = body.Close() + require.NoError(t, err) + + // Captured data should still be available + expectedCaptured := data[:100] + assert.Equal(t, expectedCaptured, body.String()) + assert.True(t, body.IsTruncated()) +} + +// Test double close safety +func TestPreReadBody_DoubleClose(t *testing.T) { + data := "test data" + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 1024) + + // First close + err := body.Close() + require.NoError(t, err) + + // Second close should be safe + err = body.Close() + require.NoError(t, err) + + // Captured data should remain available + assert.Equal(t, data, body.String()) +} + +// Test read after close +func TestPreReadBody_ReadAfterClose(t *testing.T) { + data := "test data" + reader := io.NopCloser(strings.NewReader(data)) + + body := collector.PreReadBody(reader, 1024) + + // Close first + err := body.Close() + require.NoError(t, err) + + // Try to read after close + buf := make([]byte, 10) + _, err = body.Read(buf) + + // Should return ErrBodyClosed + assert.Error(t, err) + assert.Equal(t, collector.ErrBodyClosed, err) + + // Captured data should still be accessible + assert.Equal(t, data, body.String()) + assert.Equal(t, []byte(data), body.Bytes()) +} diff --git a/collector/http_client.go b/collector/http_client.go index 64e26ae..27a1499 100644 --- a/collector/http_client.go +++ b/collector/http_client.go @@ -144,8 +144,8 @@ func (t *httpClientTransport) RoundTrip(req *http.Request) (*http.Response, erro // Capture request body if present and configured to do so if req.Body != nil && t.collector.options.CaptureRequestBody { - // Wrap the body to capture it - body := NewBody(req.Body, t.collector.options.MaxBodySize) + // Pre-read the body to ensure capture + body := PreReadBody(req.Body, t.collector.options.MaxBodySize) // Store the body in the request record httpReq.RequestBody = body @@ -187,8 +187,8 @@ func (t *httpClientTransport) RoundTrip(req *http.Request) (*http.Response, erro // Create a copy of the response to read the body even if the client doesn't originalRespBody := resp.Body - // Wrap the body to capture it - body := NewBody(originalRespBody, t.collector.options.MaxBodySize) + // Pre-read the body to ensure capture even if client doesn't read it + body := PreReadBody(originalRespBody, t.collector.options.MaxBodySize) // Store the body in the request record httpReq.ResponseBody = body diff --git a/collector/http_server.go b/collector/http_server.go index e4f0278..7249836 100644 --- a/collector/http_server.go +++ b/collector/http_server.go @@ -140,8 +140,8 @@ func (c *HTTPServerCollector) Middleware(next http.Handler) http.Handler { // Save the original body originalBody := r.Body - // Create a body wrapper - requestBody = NewBody(originalBody, c.options.MaxBodySize) + // Pre-read the body to ensure capturing bodies even if the handler writes a large response (Go net/http will close the request body then) + requestBody = PreReadBody(originalBody, c.options.MaxBodySize) // Replace the request body with our wrapper r.Body = requestBody diff --git a/collector/http_server_test.go b/collector/http_server_test.go index e344961..6424419 100644 --- a/collector/http_server_test.go +++ b/collector/http_server_test.go @@ -639,6 +639,13 @@ func TestHTTPServerCollector_UnreadRequestBodyCapture(t *testing.T) { mux.HandleFunc("/exists", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) }) + mux.HandleFunc("/large-response", func(w http.ResponseWriter, r *http.Request) { + // Write a large response (>4KB to trigger Go's behavior of closing request body) + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + largeData := strings.Repeat("x", 8192) // 8KB + w.Write([]byte(largeData)) + }) // Wrap the handler with our collector wrappedHandler := serverCollector.Middleware(mux) @@ -675,6 +682,13 @@ func TestHTTPServerCollector_UnreadRequestBodyCapture(t *testing.T) { body: "handler=doesnt&read=this&but=should&capture=it", contentType: "application/x-www-form-urlencoded", }, + { + name: "large_response_with_unread_body", + path: "/large-response", + expectedStatus: http.StatusOK, + body: "important=data&that=should&be=captured&even=with&large=response", + contentType: "application/x-www-form-urlencoded", + }, } for _, tc := range testCases { diff --git a/collector/limited_buffer.go b/collector/limited_buffer.go index 76016a5..dd7335d 100644 --- a/collector/limited_buffer.go +++ b/collector/limited_buffer.go @@ -2,6 +2,7 @@ package collector import ( "bytes" + "io" ) // LimitedBuffer is a buffer that only writes up to a certain size @@ -54,6 +55,14 @@ func (b *LimitedBuffer) Reset() { b.truncated = false } +// ReadFrom is disabled to force io.CopyN to use our Write method with truncation logic +func (b *LimitedBuffer) ReadFrom(r io.Reader) (n int64, err error) { + // Force io.CopyN to use Write() method which has proper truncation logic + // by delegating to io.CopyBuffer + buf := make([]byte, 32*1024) // Use a reasonable buffer size + return io.CopyBuffer(struct{ io.Writer }{b}, r, buf) +} + // String returns the contents of the buffer as a string. // If the buffer was truncated, it will not include the truncated data. func (b *LimitedBuffer) String() string { diff --git a/collector/limited_buffer_test.go b/collector/limited_buffer_test.go new file mode 100644 index 0000000..4dedd60 --- /dev/null +++ b/collector/limited_buffer_test.go @@ -0,0 +1,188 @@ +package collector_test + +import ( + "errors" + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/networkteam/devlog/collector" +) + +func TestLimitedBuffer_WriteWithinLimit(t *testing.T) { + buffer := collector.NewLimitedBuffer(100) + data := []byte("hello world") // 11 bytes, well within limit + + n, err := buffer.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + assert.Equal(t, string(data), buffer.String()) + assert.Equal(t, len(data), buffer.Len()) + assert.False(t, buffer.IsTruncated()) +} + +func TestLimitedBuffer_WriteExceedsLimit(t *testing.T) { + buffer := collector.NewLimitedBuffer(10) + data := []byte("this is a very long string that exceeds the limit") // 50 bytes > 10 limit + + n, err := buffer.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) // Should return full length even when truncated + assert.Equal(t, "this is a ", buffer.String()) // Only first 10 bytes + assert.Equal(t, 10, buffer.Len()) + assert.True(t, buffer.IsTruncated()) +} + +func TestLimitedBuffer_WriteExactLimit(t *testing.T) { + buffer := collector.NewLimitedBuffer(10) + data := []byte("1234567890") // Exactly 10 bytes + + n, err := buffer.Write(data) + require.NoError(t, err) + assert.Equal(t, len(data), n) + assert.Equal(t, string(data), buffer.String()) + assert.Equal(t, 10, buffer.Len()) + assert.False(t, buffer.IsTruncated()) // Should NOT be truncated +} + +// Debug wrapper to log all Write calls +type debugLimitedBuffer struct { + *collector.LimitedBuffer + t *testing.T +} + +func (d *debugLimitedBuffer) Write(p []byte) (n int, err error) { + d.t.Logf("Write called with %d bytes, buffer len before: %d, truncated before: %v", + len(p), d.LimitedBuffer.Len(), d.LimitedBuffer.IsTruncated()) + n, err = d.LimitedBuffer.Write(p) + d.t.Logf("Write returned n=%d, err=%v, buffer len after: %d, truncated after: %v", + n, err, d.LimitedBuffer.Len(), d.LimitedBuffer.IsTruncated()) + return n, err +} + +// Check if io.CopyN is using ReadFrom instead of Write +func (d *debugLimitedBuffer) ReadFrom(r io.Reader) (n int64, err error) { + d.t.Logf("ReadFrom called! This bypasses Write() entirely") + // Don't call the embedded ReadFrom - force it to use Write instead + return 0, errors.New("ReadFrom disabled for debugging") +} + +// This is the critical test for our PreReadBody use case +func TestLimitedBuffer_CopyNWithLimitPlus1(t *testing.T) { + buffer := collector.NewLimitedBuffer(100) + data := strings.Repeat("x", 200) // 200 bytes of data + reader := strings.NewReader(data) + + t.Log("Starting io.CopyN with limit=100, copying 101 bytes") + + // This is exactly what PreReadBody does: copy limit+1 bytes + // When ReadFrom fails, io.CopyN should fall back to using Write() method + n, err := io.CopyN(buffer, reader, int64(101)) // limit+1 + + // The behavior when ReadFrom returns an error is that io.CopyN falls back to Write + // So we should get successful copy but with proper truncation + require.NoError(t, err) + assert.Equal(t, int64(101), n) // io.CopyN reports copying 101 bytes + + // Critical assertions: our Write method should have enforced the limit + t.Logf("Final: Buffer length: %d", buffer.Len()) + t.Logf("Final: Buffer truncated: %v", buffer.IsTruncated()) + t.Logf("Final: Buffer content length: %d", len(buffer.String())) + + // What we EXPECT should happen with the ReadFrom disabled: + assert.Equal(t, 100, buffer.Len(), "Buffer should contain only 100 bytes") + assert.True(t, buffer.IsTruncated(), "Buffer should be marked as truncated") + assert.Equal(t, strings.Repeat("x", 100), buffer.String(), "Buffer should contain first 100 chars") +} + +func TestLimitedBuffer_MultipleWrites(t *testing.T) { + buffer := collector.NewLimitedBuffer(10) + + // First write: 5 bytes + n, err := buffer.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.False(t, buffer.IsTruncated()) + + // Second write: 3 bytes (total 8, still within limit) + n, err = buffer.Write([]byte(" wo")) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.False(t, buffer.IsTruncated()) + + // Third write: 5 bytes (would make total 13, exceeds limit of 10) + n, err = buffer.Write([]byte("rld!!")) + require.NoError(t, err) + assert.Equal(t, 5, n) // Should return full length + assert.True(t, buffer.IsTruncated()) // Should be truncated + assert.Equal(t, "hello worl", buffer.String()) // First 10 chars total (8 + 2 from "rld!!") + assert.Equal(t, 10, buffer.Len()) +} + +func TestLimitedBuffer_WriteAfterTruncation(t *testing.T) { + buffer := collector.NewLimitedBuffer(5) + + // First write exceeds limit + n, err := buffer.Write([]byte("hello world")) + require.NoError(t, err) + assert.Equal(t, 11, n) + assert.True(t, buffer.IsTruncated()) + assert.Equal(t, "hello", buffer.String()) + + // Subsequent writes should be ignored + n, err = buffer.Write([]byte(" more")) + require.NoError(t, err) + assert.Equal(t, 5, n) // Returns length as if written + assert.True(t, buffer.IsTruncated()) + assert.Equal(t, "hello", buffer.String()) // Unchanged + assert.Equal(t, 5, buffer.Len()) +} + +func TestLimitedBuffer_ZeroLimit(t *testing.T) { + buffer := collector.NewLimitedBuffer(0) + + n, err := buffer.Write([]byte("test")) + require.NoError(t, err) + assert.Equal(t, 4, n) + assert.True(t, buffer.IsTruncated()) + assert.Equal(t, "", buffer.String()) + assert.Equal(t, 0, buffer.Len()) +} + +func TestLimitedBuffer_EmptyWrite(t *testing.T) { + buffer := collector.NewLimitedBuffer(10) + + n, err := buffer.Write([]byte{}) + require.NoError(t, err) + assert.Equal(t, 0, n) + assert.False(t, buffer.IsTruncated()) + assert.Equal(t, "", buffer.String()) + assert.Equal(t, 0, buffer.Len()) +} + +func TestLimitedBuffer_Reset(t *testing.T) { + buffer := collector.NewLimitedBuffer(5) + + // Write data that exceeds limit + n, err := buffer.Write([]byte("hello world")) + require.NoError(t, err) + assert.Equal(t, 11, n) + assert.True(t, buffer.IsTruncated()) + assert.Equal(t, "hello", buffer.String()) + + // Reset should clear everything + buffer.Reset() + assert.False(t, buffer.IsTruncated()) + assert.Equal(t, "", buffer.String()) + assert.Equal(t, 0, buffer.Len()) + + // Should work normally after reset + n, err = buffer.Write([]byte("new")) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.False(t, buffer.IsTruncated()) + assert.Equal(t, "new", buffer.String()) +}