Skip to content

Commit c621dac

Browse files
authored
simplify startup by using the new wg.Go feature (#65)
* simplify code by using the new wg.Go feature * add more test coverage
1 parent a21d98d commit c621dac

File tree

7 files changed

+437
-105
lines changed

7 files changed

+437
-105
lines changed

examples/http/main_test.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"io"
67
"log/slog"
78
"net/http"
@@ -40,8 +41,15 @@ func TestRunServer(t *testing.T) {
4041
errCh <- sv.Run()
4142
}()
4243

43-
// Give the server a moment to start
44-
time.Sleep(100 * time.Millisecond)
44+
// Wait for the server to be ready by checking if it responds to requests
45+
assert.Eventually(t, func() bool {
46+
resp, err := http.Get("http://localhost:8080/status")
47+
if err != nil {
48+
return false
49+
}
50+
defer func() { assert.NoError(t, resp.Body.Close()) }()
51+
return resp.StatusCode == http.StatusOK
52+
}, 2*time.Second, 50*time.Millisecond, "Server should become ready")
4553

4654
// Make a request to the server
4755
resp, err := http.Get("http://localhost:8080/status")
@@ -56,12 +64,14 @@ func TestRunServer(t *testing.T) {
5664
// Stop the supervisor
5765
sv.Shutdown()
5866

59-
// Check that Run() didn't return an error
67+
// Wait for Run() to complete and check the result
6068
select {
6169
case err := <-errCh:
62-
require.NoError(t, err, "Run() should not return an error")
63-
case <-time.After(100 * time.Millisecond):
64-
// This is expected - the server is still running
70+
if err != nil && !errors.Is(err, context.Canceled) {
71+
require.NoError(t, err, "Run() should not return an error")
72+
}
73+
case <-time.After(2 * time.Second):
74+
t.Fatal("Run() should have completed within timeout")
6575
}
6676
}
6777

supervisor/reload.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ func (p *PIDZero) ReloadAll() {
2525
// and calls the reload method on all reloadable services. This will also prevent
2626
// multiple reloads from happening concurrently.
2727
func (p *PIDZero) startReloadManager() {
28-
defer p.wg.Done()
2928
p.logger.Debug("Starting reload manager...")
3029

3130
// iterate all the runnables, and find the ones that are can send reload notifications

supervisor/shutdown.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import "sync"
55
// startShutdownManager starts goroutines to listen for shutdown notifications
66
// from any runnables that implement ShutdownSender. It blocks until the context is done.
77
func (p *PIDZero) startShutdownManager() {
8-
defer p.wg.Done()
98
p.logger.Debug("Starting shutdown manager...")
109

1110
var shutdownWg sync.WaitGroup

supervisor/shutdown_test.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ func TestPIDZero_StartShutdownManager_TriggersShutdown(t *testing.T) {
3434
pidZero, err := New(WithContext(supervisorCtx), WithRunnables(mockService))
3535
require.NoError(t, err)
3636

37-
// Start the shutdown manager in a goroutine
38-
pidZero.wg.Add(1)
39-
go pidZero.startShutdownManager()
37+
// Start the shutdown manager using wg.Go
38+
pidZero.wg.Go(pidZero.startShutdownManager)
4039

4140
time.Sleep(200 * time.Millisecond)
4241

@@ -81,9 +80,8 @@ func TestPIDZero_StartShutdownManager_ContextCancel(t *testing.T) {
8180
pidZero, err := New(WithContext(ctx), WithRunnables(mockService))
8281
require.NoError(t, err)
8382

84-
// Start the shutdown manager in a goroutine
85-
pidZero.wg.Add(1)
86-
go pidZero.startShutdownManager()
83+
// Start the shutdown manager using wg.Go
84+
pidZero.wg.Go(pidZero.startShutdownManager)
8785

8886
// Give the manager a moment to start its internal listener goroutine
8987
time.Sleep(100 * time.Millisecond) // Increased sleep duration
@@ -128,9 +126,8 @@ func TestPIDZero_StartShutdownManager_NoSenders(t *testing.T) {
128126
pidZero, err := New(WithContext(ctx), WithRunnables(nonSenderRunnable))
129127
require.NoError(t, err)
130128

131-
// Start the shutdown manager in a goroutine
132-
pidZero.wg.Add(1)
133-
go pidZero.startShutdownManager()
129+
// Start the shutdown manager using wg.Go
130+
pidZero.wg.Go(pidZero.startShutdownManager)
134131

135132
// Give the manager a moment to start
136133
time.Sleep(50 * time.Millisecond)

supervisor/state.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ func (p *PIDZero) broadcastState() {
163163
// Stateable runnable. It blocks until the context is done, coordinating state updates
164164
// from all state-emitting services.
165165
func (p *PIDZero) startStateMonitor() {
166-
defer p.wg.Done()
167166
p.logger.Debug("Starting state monitor...")
168167

169168
// Create a WaitGroup to track state monitoring goroutines

supervisor/supervisor.go

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -190,34 +190,36 @@ func (p *PIDZero) Run() error {
190190
// Start a single reload manager if any runnable is reloadable
191191
for _, r := range p.runnables {
192192
if _, ok := r.(Reloadable); ok {
193-
p.wg.Add(1)
194-
go p.startReloadManager()
193+
p.wg.Go(p.startReloadManager)
195194
break
196195
}
197196
}
198197

199198
// Start a single state monitor if any runnable reports state
200199
for _, r := range p.runnables {
201200
if _, ok := r.(Stateable); ok {
202-
p.wg.Add(1)
203-
go p.startStateMonitor()
201+
p.wg.Go(p.startStateMonitor)
204202
break
205203
}
206204
}
207205

208206
// Start a single shutdown manager if any runnable can trigger shutdown
209207
for _, r := range p.runnables {
210208
if _, ok := r.(ShutdownSender); ok {
211-
p.wg.Add(1)
212-
go p.startShutdownManager()
209+
p.wg.Go(p.startShutdownManager)
213210
break
214211
}
215212
}
216213

217214
// Start each service in sequence
218215
for _, r := range p.runnables {
219-
p.wg.Add(1)
220-
go p.startRunnable(r) // start this runnable in a separate goroutine
216+
p.wg.Go(func() {
217+
err := p.startRunnable(r)
218+
if err != nil {
219+
p.logger.Error("Runnable exited with error", "runnable", r, "error", err)
220+
p.errorChan <- err
221+
}
222+
})
221223

222224
// if this Runnable implements the Stateable block here until IsRunning()
223225
if stateable, ok := r.(Stateable); ok {
@@ -263,9 +265,7 @@ func (p *PIDZero) blockUntilRunnableReady(r Stateable) error {
263265
)
264266
select {
265267
case err := <-p.errorChan:
266-
// error received from `startRunnable`- put it back in the channel for reap() to process it later
267-
p.errorChan <- err
268-
return fmt.Errorf("runnable failed to start: %w", err)
268+
return err
269269
case <-startupCtx.Done():
270270
return fmt.Errorf("timeout waiting for runnable to start: %w", startupCtx.Err())
271271
case <-p.ctx.Done():
@@ -362,9 +362,7 @@ func (p *PIDZero) listenForSignals() {
362362
}
363363

364364
// startRunnable starts a service and sends any errors to the error channel
365-
func (p *PIDZero) startRunnable(r Runnable) {
366-
defer p.wg.Done()
367-
365+
func (p *PIDZero) startRunnable(r Runnable) error {
368366
// Log the initial state if available
369367
if stateable, ok := r.(Stateable); ok {
370368
initialState := stateable.GetState()
@@ -378,31 +376,16 @@ func (p *PIDZero) startRunnable(r Runnable) {
378376

379377
// Run the runnable with the child context
380378
err := r.Run(runCtx)
381-
logger := p.logger.With("runnable", r)
382-
if err == nil {
383-
logger.Debug("Runnable completed without error")
384-
return
385-
}
386-
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
387-
// Filter out expected cancellation errors
388-
logger.Debug("Runnable stopped gracefully", "reason", err)
389-
return
390-
}
391-
392-
// Unexpected error - log and send to the errorChan
393-
logger = logger.With("error", err)
394-
if stateable, ok := r.(Stateable); ok {
395-
logger.Error("Service failed", "state", stateable.GetState())
396-
} else {
397-
logger.Error("Service failed")
398-
}
379+
if err != nil {
380+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
381+
// Filter out expected cancellation errors
382+
p.logger.Debug("Runnable stopped gracefully", "reason", err, "runnable", r)
383+
return nil
384+
}
399385

400-
select {
401-
case p.errorChan <- fmt.Errorf("failed to start runnable: %w", err):
402-
// Error sent successfully
403-
default:
404-
logger.Warn("Unable to send error to errorChan (full or closed?)")
386+
return err
405387
}
388+
return nil
406389
}
407390

408391
// reap listens indefinitely for errors or OS signals and handles them appropriately.

0 commit comments

Comments
 (0)