Skip to content

Commit e59ddc1

Browse files
committed
expose internal.handleHTTP as a standard http middleware
1 parent b48684e commit e59ddc1

File tree

5 files changed

+92
-79
lines changed

5 files changed

+92
-79
lines changed

appengine.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ func Main() {
5454
internal.Main()
5555
}
5656

57+
// Middleware wraps an http handler so that it can make GAE API calls
58+
var Middleware func(http.Handler) http.Handler = internal.Middleware
59+
5760
// IsDevAppServer reports whether the App Engine app is running in the
5861
// development App Server.
5962
func IsDevAppServer() bool {

internal/api.go

Lines changed: 82 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -87,88 +87,98 @@ func apiURL() *url.URL {
8787
}
8888
}
8989

90-
func handleHTTP(w http.ResponseWriter, r *http.Request) {
91-
c := &context{
92-
req: r,
93-
outHeader: w.Header(),
94-
apiURL: apiURL(),
95-
}
96-
r = r.WithContext(withContext(r.Context(), c))
97-
c.req = r
98-
99-
stopFlushing := make(chan int)
90+
// Middleware wraps an http handler so that it can make GAE API calls
91+
func Middleware(next http.Handler) http.Handler {
92+
return handleHTTPMiddleware(executeRequestSafelyMiddleware(next))
93+
}
10094

101-
// Patch up RemoteAddr so it looks reasonable.
102-
if addr := r.Header.Get(userIPHeader); addr != "" {
103-
r.RemoteAddr = addr
104-
} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
105-
r.RemoteAddr = addr
106-
} else {
107-
// Should not normally reach here, but pick a sensible default anyway.
108-
r.RemoteAddr = "127.0.0.1"
109-
}
110-
// The address in the headers will most likely be of these forms:
111-
// 123.123.123.123
112-
// 2001:db8::1
113-
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
114-
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
115-
// Assume the remote address is only a host; add a default port.
116-
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
117-
}
95+
func handleHTTPMiddleware(next http.Handler) http.Handler {
96+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
97+
c := &context{
98+
req: r,
99+
outHeader: w.Header(),
100+
apiURL: apiURL(),
101+
}
102+
r = r.WithContext(withContext(r.Context(), c))
103+
c.req = r
104+
105+
stopFlushing := make(chan int)
106+
107+
// Patch up RemoteAddr so it looks reasonable.
108+
if addr := r.Header.Get(userIPHeader); addr != "" {
109+
r.RemoteAddr = addr
110+
} else if addr = r.Header.Get(remoteAddrHeader); addr != "" {
111+
r.RemoteAddr = addr
112+
} else {
113+
// Should not normally reach here, but pick a sensible default anyway.
114+
r.RemoteAddr = "127.0.0.1"
115+
}
116+
// The address in the headers will most likely be of these forms:
117+
// 123.123.123.123
118+
// 2001:db8::1
119+
// net/http.Request.RemoteAddr is specified to be in "IP:port" form.
120+
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
121+
// Assume the remote address is only a host; add a default port.
122+
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
123+
}
118124

119-
if logToLogservice() {
120-
// Start goroutine responsible for flushing app logs.
121-
// This is done after adding c to ctx.m (and stopped before removing it)
122-
// because flushing logs requires making an API call.
123-
go c.logFlusher(stopFlushing)
124-
}
125+
if logToLogservice() {
126+
// Start goroutine responsible for flushing app logs.
127+
// This is done after adding c to ctx.m (and stopped before removing it)
128+
// because flushing logs requires making an API call.
129+
go c.logFlusher(stopFlushing)
130+
}
125131

126-
executeRequestSafely(c, r)
127-
c.outHeader = nil // make sure header changes aren't respected any more
132+
next.ServeHTTP(c, r)
133+
c.outHeader = nil // make sure header changes aren't respected any more
128134

129-
flushed := make(chan struct{})
130-
if logToLogservice() {
131-
stopFlushing <- 1 // any logging beyond this point will be dropped
135+
flushed := make(chan struct{})
136+
if logToLogservice() {
137+
stopFlushing <- 1 // any logging beyond this point will be dropped
132138

133-
// Flush any pending logs asynchronously.
134-
c.pendingLogs.Lock()
135-
flushes := c.pendingLogs.flushes
136-
if len(c.pendingLogs.lines) > 0 {
137-
flushes++
139+
// Flush any pending logs asynchronously.
140+
c.pendingLogs.Lock()
141+
flushes := c.pendingLogs.flushes
142+
if len(c.pendingLogs.lines) > 0 {
143+
flushes++
144+
}
145+
c.pendingLogs.Unlock()
146+
go func() {
147+
defer close(flushed)
148+
// Force a log flush, because with very short requests we
149+
// may not ever flush logs.
150+
c.flushLog(true)
151+
}()
152+
w.Header().Set(logFlushHeader, strconv.Itoa(flushes))
138153
}
139-
c.pendingLogs.Unlock()
140-
go func() {
141-
defer close(flushed)
142-
// Force a log flush, because with very short requests we
143-
// may not ever flush logs.
144-
c.flushLog(true)
145-
}()
146-
w.Header().Set(logFlushHeader, strconv.Itoa(flushes))
147-
}
148154

149-
// Avoid nil Write call if c.Write is never called.
150-
if c.outCode != 0 {
151-
w.WriteHeader(c.outCode)
152-
}
153-
if c.outBody != nil {
154-
w.Write(c.outBody)
155-
}
156-
if logToLogservice() {
157-
// Wait for the last flush to complete before returning,
158-
// otherwise the security ticket will not be valid.
159-
<-flushed
160-
}
155+
// Avoid nil Write call if c.Write is never called.
156+
if c.outCode != 0 {
157+
w.WriteHeader(c.outCode)
158+
}
159+
if c.outBody != nil {
160+
w.Write(c.outBody)
161+
}
162+
if logToLogservice() {
163+
// Wait for the last flush to complete before returning,
164+
// otherwise the security ticket will not be valid.
165+
<-flushed
166+
}
167+
})
161168
}
162169

163-
func executeRequestSafely(c *context, r *http.Request) {
164-
defer func() {
165-
if x := recover(); x != nil {
166-
logf(c, 4, "%s", renderPanic(x)) // 4 == critical
167-
c.outCode = 500
168-
}
169-
}()
170+
func executeRequestSafelyMiddleware(next http.Handler) http.Handler {
171+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
172+
defer func() {
173+
if x := recover(); x != nil {
174+
c := w.(*context)
175+
logf(c, 4, "%s", renderPanic(x)) // 4 == critical
176+
c.outCode = 500
177+
}
178+
}()
170179

171-
http.DefaultServeMux.ServeHTTP(c, r)
180+
next.ServeHTTP(w, r)
181+
})
172182
}
173183

174184
func renderPanic(x interface{}) string {

internal/api_classic.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ func Call(ctx netcontext.Context, service, method string, in, out proto.Message)
144144
return err
145145
}
146146

147-
func handleHTTP(w http.ResponseWriter, r *http.Request) {
148-
panic("handleHTTP called; this should be impossible")
147+
func Middleware(next http.Handler) http.Handler {
148+
panic("Middleware called; this should be impossible")
149149
}
150150

151151
func logf(c appengine.Context, level int64, format string, args ...interface{}) {

internal/api_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ func TestDelayedLogFlushing(t *testing.T) {
302302
handled := make(chan struct{})
303303
go func() {
304304
defer close(handled)
305-
handleHTTP(w, r)
305+
Middleware(http.DefaultServeMux).ServeHTTP(w, r)
306306
}()
307307
// Check that the log flush eventually comes in.
308308
time.Sleep(1200 * time.Millisecond)
@@ -360,7 +360,7 @@ func TestLogFlushing(t *testing.T) {
360360
}
361361
w := httptest.NewRecorder()
362362

363-
handleHTTP(w, r)
363+
Middleware(http.DefaultServeMux).ServeHTTP(w, r)
364364
const hdr = "X-AppEngine-Log-Flush-Count"
365365
if got := w.HeaderMap.Get(hdr); got != tc.wantHeader {
366366
t.Errorf("%s header = %q, want %q", hdr, got, tc.wantHeader)
@@ -403,7 +403,7 @@ func TestRemoteAddr(t *testing.T) {
403403
Header: tc.headers,
404404
Body: ioutil.NopCloser(bytes.NewReader(nil)),
405405
}
406-
handleHTTP(httptest.NewRecorder(), r)
406+
Middleware(http.DefaultServeMux).ServeHTTP(httptest.NewRecorder(), r)
407407
if addr != tc.addr {
408408
t.Errorf("Header %v, got %q, want %q", tc.headers, addr, tc.addr)
409409
}
@@ -420,7 +420,7 @@ func TestPanickingHandler(t *testing.T) {
420420
Body: ioutil.NopCloser(bytes.NewReader(nil)),
421421
}
422422
rec := httptest.NewRecorder()
423-
handleHTTP(rec, r)
423+
Middleware(http.DefaultServeMux).ServeHTTP(rec, r)
424424
if rec.Code != 500 {
425425
t.Errorf("Panicking handler returned HTTP %d, want HTTP %d", rec.Code, 500)
426426
}

internal/main_vm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func Main() {
2929
if IsDevAppServer() {
3030
host = "127.0.0.1"
3131
}
32-
if err := http.ListenAndServe(host+":"+port, http.HandlerFunc(handleHTTP)); err != nil {
32+
if err := http.ListenAndServe(host+":"+port, Middleware(http.DefaultServeMux)); err != nil {
3333
log.Fatalf("http.ListenAndServe: %v", err)
3434
}
3535
}

0 commit comments

Comments
 (0)