Skip to content

Commit 0cde61b

Browse files
committed
httputil: add response recorder
This implements a new-style response recording `http.ResponseWriter`. Signed-off-by: Hank Donnay <[email protected]>
1 parent c90a55f commit 0cde61b

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package httputil
2+
3+
import "net/http"
4+
5+
// ResponseRecorder returns a ResponseWriter that records the HTTP status and
6+
// body length into the provided pointers, and returns another response writer
7+
// that understand the go 1.20 http `Unwrap` scheme.
8+
func ResponseRecorder(status *int, length *int64, w http.ResponseWriter) http.ResponseWriter {
9+
// Handle nils being passed, just to be nice.
10+
if length == nil {
11+
length = new(int64)
12+
}
13+
if status == nil {
14+
status = new(int)
15+
}
16+
return &responseRecord{
17+
ResponseWriter: w,
18+
status: status,
19+
length: length,
20+
}
21+
}
22+
23+
var _ http.ResponseWriter = (*responseRecord)(nil)
24+
25+
type responseRecord struct {
26+
http.ResponseWriter
27+
status *int
28+
length *int64
29+
writecall bool
30+
}
31+
32+
func (r *responseRecord) Unwrap() http.ResponseWriter {
33+
return r.ResponseWriter
34+
}
35+
36+
func (r *responseRecord) WriteHeader(c int) {
37+
if r.writecall {
38+
return
39+
}
40+
*r.status = c
41+
r.ResponseWriter.WriteHeader(c)
42+
r.writecall = true
43+
}
44+
45+
func (r *responseRecord) Write(b []byte) (int, error) {
46+
r.WriteHeader(http.StatusOK)
47+
n, err := r.ResponseWriter.Write(b)
48+
*r.length += int64(n)
49+
return n, err
50+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package httputil
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
)
8+
9+
func TestResponseRecorder(t *testing.T) {
10+
var status int
11+
var length int64
12+
13+
rec := httptest.NewRecorder()
14+
w := ResponseRecorder(&status, &length, rec)
15+
16+
sz := 512
17+
if n, err := w.Write(make([]byte, sz)); err != nil || n != sz {
18+
t.Errorf("unexpected Write return: (%v, %v)", n, err)
19+
}
20+
t.Logf("wrote %d bytes, status %q", length, http.StatusText(status))
21+
if got, want := status, http.StatusOK; got != want {
22+
t.Errorf("bad status; got: %d, want: %d", got, want)
23+
}
24+
if got, want := length, int64(sz); got != want {
25+
t.Errorf("bad length; got: %d, want: %d", got, want)
26+
}
27+
}

0 commit comments

Comments
 (0)