Skip to content

Commit 28bf581

Browse files
authored
just a small refactor of api_test (#291)
This makes the test more hermetic by avoiding the need to set env vars, and it also avoids some unecessary duplication of test helper logic by leveraging some of aetest's underlying implementation. This change was originally part of #284, but I split it out because it's not compatible with v1's log flushing tests, and it would have added unecessary noise to that PR.
1 parent 6e2c50e commit 28bf581

File tree

3 files changed

+32
-65
lines changed

3 files changed

+32
-65
lines changed

v2/internal/api.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,8 @@ func WithContext(parent context.Context, req *http.Request) context.Context {
258258
}
259259

260260
// RegisterTestRequest registers the HTTP request req for testing, such that
261-
// any API calls are sent to the provided URL. It returns a closure to delete
262-
// the registration.
263-
// It should only be used by aetest package.
261+
// any API calls are sent to the provided URL.
262+
// It should only be used by test code or test helpers like aetest.
264263
func RegisterTestRequest(req *http.Request, apiURL *url.URL, appID string) *http.Request {
265264
ctx := req.Context()
266265
ctx = withAPIHostOverride(ctx, apiURL.Hostname())

v2/internal/api_test.go

Lines changed: 27 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -141,51 +141,35 @@ func (f *fakeAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
141141
})
142142
}
143143

144-
func setup() (f *fakeAPIHandler, c *aeContext, cleanup func()) {
144+
func makeTestRequest(apiURL *url.URL) *http.Request {
145+
req := &http.Request{
146+
Header: http.Header{
147+
ticketHeader: []string{"s3cr3t"},
148+
dapperHeader: []string{"trace-001"},
149+
},
150+
}
151+
return RegisterTestRequest(req, apiURL, "")
152+
}
153+
154+
func setup() (f *fakeAPIHandler, r *http.Request, cleanup func()) {
145155
f = &fakeAPIHandler{}
146156
srv := httptest.NewServer(f)
147-
u, err := url.Parse(srv.URL + apiPath)
148-
restoreAPIHost := restoreEnvVar("API_HOST")
149-
restoreAPIPort := restoreEnvVar("API_HOST")
150-
os.Setenv("API_HOST", u.Hostname())
151-
os.Setenv("API_PORT", u.Port())
157+
apiURL, err := url.Parse(srv.URL + apiPath)
152158
if err != nil {
153159
panic(fmt.Sprintf("url.Parse(%q): %v", srv.URL+apiPath, err))
154160
}
155-
return f, &aeContext{
156-
req: &http.Request{
157-
Header: http.Header{
158-
ticketHeader: []string{"s3cr3t"},
159-
dapperHeader: []string{"trace-001"},
160-
},
161-
},
162-
}, func() {
163-
restoreAPIHost()
164-
restoreAPIPort()
165-
srv.Close()
166-
}
167-
}
168-
169-
func restoreEnvVar(key string) (cleanup func()) {
170-
oldval, ok := os.LookupEnv(key)
171-
return func() {
172-
if ok {
173-
os.Setenv(key, oldval)
174-
} else {
175-
os.Unsetenv(key)
176-
}
177-
}
161+
return f, makeTestRequest(apiURL), srv.Close
178162
}
179163

180164
func TestAPICall(t *testing.T) {
181-
_, c, cleanup := setup()
165+
_, r, cleanup := setup()
182166
defer cleanup()
183167

184168
req := &basepb.StringProto{
185169
Value: proto.String("Doctor Who"),
186170
}
187171
res := &basepb.StringProto{}
188-
err := Call(toContext(c), "actordb", "LookupActor", req, res)
172+
err := Call(r.Context(), "actordb", "LookupActor", req, res)
189173
if err != nil {
190174
t.Fatalf("API call failed: %v", err)
191175
}
@@ -195,18 +179,16 @@ func TestAPICall(t *testing.T) {
195179
}
196180

197181
func TestAPICallTicketUnavailable(t *testing.T) {
198-
resetEnv := SetTestEnv()
199-
defer resetEnv()
200-
f, c, cleanup := setup()
182+
f, r, cleanup := setup()
201183
defer cleanup()
202184
f.allowMissingTicket = true
203185

204-
c.req.Header.Set(ticketHeader, "")
186+
r.Header.Set(ticketHeader, "")
205187
req := &basepb.StringProto{
206188
Value: proto.String("Doctor Who"),
207189
}
208190
res := &basepb.StringProto{}
209-
err := Call(toContext(c), "actordb", "LookupActor", req, res)
191+
err := Call(r.Context(), "actordb", "LookupActor", req, res)
210192
if err != nil {
211193
t.Fatalf("API call failed: %v", err)
212194
}
@@ -216,7 +198,7 @@ func TestAPICallTicketUnavailable(t *testing.T) {
216198
}
217199

218200
func TestAPICallRPCFailure(t *testing.T) {
219-
f, c, cleanup := setup()
201+
f, r, cleanup := setup()
220202
defer cleanup()
221203

222204
testCases := []struct {
@@ -230,7 +212,7 @@ func TestAPICallRPCFailure(t *testing.T) {
230212
}
231213
f.hang = make(chan int) // only for RunSlowly
232214
for _, tc := range testCases {
233-
ctx, _ := context.WithTimeout(toContext(c), 100*time.Millisecond)
215+
ctx, _ := context.WithTimeout(r.Context(), 100*time.Millisecond)
234216
err := Call(ctx, "errors", tc.method, &basepb.VoidProto{}, &basepb.VoidProto{})
235217
ce, ok := err.(*CallError)
236218
if !ok {
@@ -247,9 +229,7 @@ func TestAPICallRPCFailure(t *testing.T) {
247229
}
248230

249231
func TestAPICallDialFailure(t *testing.T) {
250-
// See what happens if the API host is unresponsive.
251-
// This should time out quickly, not hang forever.
252-
// We intentially don't set up the fakeAPIHandler for this test to cause the dail failure.
232+
// we intentially don't set up the fakeAPIHandler for this test to cause the dail failure
253233
start := time.Now()
254234
err := Call(context.Background(), "foo", "bar", &basepb.VoidProto{}, &basepb.VoidProto{})
255235
const max = 1 * time.Second
@@ -323,24 +303,18 @@ func TestAPICallAllocations(t *testing.T) {
323303
}
324304

325305
// Run the test API server in a subprocess so we aren't counting its allocations.
326-
cleanup := launchHelperProcess(t)
306+
apiURL, cleanup := launchHelperProcess(t)
327307
defer cleanup()
328-
c := &aeContext{
329-
req: &http.Request{
330-
Header: http.Header{
331-
ticketHeader: []string{"s3cr3t"},
332-
dapperHeader: []string{"trace-001"},
333-
},
334-
},
335-
}
308+
309+
r := makeTestRequest(apiURL)
336310

337311
req := &basepb.StringProto{
338312
Value: proto.String("Doctor Who"),
339313
}
340314
res := &basepb.StringProto{}
341315
var apiErr error
342316
avg := testing.AllocsPerRun(100, func() {
343-
ctx, _ := context.WithTimeout(toContext(c), 100*time.Millisecond)
317+
ctx, _ := context.WithTimeout(r.Context(), 100*time.Millisecond)
344318
if err := Call(ctx, "actordb", "LookupActor", req, res); err != nil && apiErr == nil {
345319
apiErr = err // get the first error only
346320
}
@@ -356,7 +330,7 @@ func TestAPICallAllocations(t *testing.T) {
356330
}
357331
}
358332

359-
func launchHelperProcess(t *testing.T) (cleanup func()) {
333+
func launchHelperProcess(t *testing.T) (apiURL *url.URL, cleanup func()) {
360334
cmd := exec.Command(os.Args[0], "-test.run=TestHelperProcess")
361335
cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"}
362336
stdin, err := cmd.StdinPipe()
@@ -391,13 +365,7 @@ func launchHelperProcess(t *testing.T) (cleanup func()) {
391365
t.Fatal("Helper process never reported")
392366
}
393367

394-
restoreAPIHost := restoreEnvVar("API_HOST")
395-
restoreAPIPort := restoreEnvVar("API_HOST")
396-
os.Setenv("API_HOST", u.Hostname())
397-
os.Setenv("API_PORT", u.Port())
398-
return func() {
399-
restoreAPIHost()
400-
restoreAPIPort()
368+
return u, func() {
401369
stdin.Close()
402370
if err := cmd.Wait(); err != nil {
403371
t.Errorf("Helper process did not exit cleanly: %v", err)

v2/internal/net_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestDialLimit(t *testing.T) {
2626
}
2727
}()
2828

29-
f, c, cleanup := setup() // setup is in api_test.go
29+
f, r, cleanup := setup() // setup is in api_test.go
3030
defer cleanup()
3131
f.hang = make(chan int)
3232

@@ -37,12 +37,12 @@ func TestDialLimit(t *testing.T) {
3737
for i := 0; i < 2; i++ {
3838
go func() {
3939
defer wg.Done()
40-
Call(toContext(c), "errors", "RunSlowly", &basepb.VoidProto{}, &basepb.VoidProto{})
40+
Call(r.Context(), "errors", "RunSlowly", &basepb.VoidProto{}, &basepb.VoidProto{})
4141
}()
4242
}
4343
time.Sleep(50 * time.Millisecond) // let those two RPCs start
4444

45-
ctx, _ := context.WithTimeout(toContext(c), 50*time.Millisecond)
45+
ctx, _ := context.WithTimeout(r.Context(), 50*time.Millisecond)
4646
err := Call(ctx, "errors", "Non200", &basepb.VoidProto{}, &basepb.VoidProto{})
4747
if err != errTimeout {
4848
t.Errorf("Non200 RPC returned with err %v, want errTimeout", err)

0 commit comments

Comments
 (0)